Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] spec decode handle None entries in topk args in create_sequence_group_output #7232

Merged
merged 2 commits into from
Aug 22, 2024

Conversation

tjohnson31415
Copy link
Contributor

@tjohnson31415 tjohnson31415 commented Aug 6, 2024

Currently, requesting logprobs for generated tokens when using speculative decoding is broken. Sending a request for logprobs causes the server to crash with:

ERROR 08-06 23:49:35 logs.py:91] generate{input=[b'<|user|>\\nThe future of ai is'] prefix_id= adapter_id= input_chars=[28] params=stopping { max_new_tokens: 16 } response { generated_tokens: true token_logprobs: true } decoding { repetition_penalty: 1.0 }: '>=' not supported between instances of 'NoneType' and 'int'
ERROR 08-06 23:49:35 grpc_server.py:145] Generate failed
ERROR 08-06 23:49:35 grpc_server.py:145] Traceback (most recent call last):
ERROR 08-06 23:49:35 grpc_server.py:145]   File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm_tgis_adapter/grpc/grpc_server.py", line 162, in func_with_log
ERROR 08-06 23:49:35 grpc_server.py:145]     return await func(*args, **kwargs)
ERROR 08-06 23:49:35 grpc_server.py:145]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-06 23:49:35 grpc_server.py:145]   File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm_tgis_adapter/grpc/grpc_server.py", line 275, in Generate
ERROR 08-06 23:49:35 grpc_server.py:145]     async for i, res in result_generator:
ERROR 08-06 23:49:35 grpc_server.py:145]   File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/utils.py", line 346, in consumer
ERROR 08-06 23:49:35 grpc_server.py:145]     raise e
ERROR 08-06 23:49:35 grpc_server.py:145]   File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/utils.py", line 337, in consumer
ERROR 08-06 23:49:35 grpc_server.py:145]     raise item
ERROR 08-06 23:49:35 grpc_server.py:145]   File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/utils.py", line 312, in producer
ERROR 08-06 23:49:35 grpc_server.py:145]     async for item in iterator:
ERROR 08-06 23:49:35 grpc_server.py:145]   File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 777, in generate
ERROR 08-06 23:49:35 grpc_server.py:145]     async for output in self._process_request(
ERROR 08-06 23:49:35 grpc_server.py:145]   File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 893, in _process_request
ERROR 08-06 23:49:35 grpc_server.py:145]     raise e
ERROR 08-06 23:49:35 grpc_server.py:145]   File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 889, in _process_request
ERROR 08-06 23:49:35 grpc_server.py:145]     async for request_output in stream:
ERROR 08-06 23:49:35 grpc_server.py:145]   File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 94, in __anext__
ERROR 08-06 23:49:35 grpc_server.py:145]     raise result
ERROR 08-06 23:49:35 grpc_server.py:145]   File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 47, in _log_task_completion
ERROR 08-06 23:49:35 grpc_server.py:145]     return_value = task.result()
ERROR 08-06 23:49:35 grpc_server.py:145]                    ^^^^^^^^^^^^^
ERROR 08-06 23:49:35 grpc_server.py:145]   File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 642, in run_engine_loop
ERROR 08-06 23:49:35 grpc_server.py:145]     result = task.result()
ERROR 08-06 23:49:35 grpc_server.py:145]              ^^^^^^^^^^^^^
ERROR 08-06 23:49:35 grpc_server.py:145]   File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 585, in engine_step
ERROR 08-06 23:49:35 grpc_server.py:145]     request_outputs = await self.engine.step_async(virtual_engine)
ERROR 08-06 23:49:35 grpc_server.py:145]                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-06 23:49:35 grpc_server.py:145]   File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 259, in step_async
ERROR 08-06 23:49:35 grpc_server.py:145]     request_outputs = self._process_model_outputs(
ERROR 08-06 23:49:35 grpc_server.py:145]                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-06 23:49:35 grpc_server.py:145]   File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/engine/llm_engine.py", line 831, in _process_model_outputs
ERROR 08-06 23:49:35 grpc_server.py:145]     self.output_processor.process_outputs(seq_group, outputs)
ERROR 08-06 23:49:35 grpc_server.py:145]   File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/engine/output_processor/multi_step.py", line 90, in process_outputs
ERROR 08-06 23:49:35 grpc_server.py:145]     self._process_seq_outputs(seq, valid_samples,
ERROR 08-06 23:49:35 grpc_server.py:145]   File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/engine/output_processor/multi_step.py", line 131, in _process_seq_outputs
ERROR 08-06 23:49:35 grpc_server.py:145]     new_char_count = self.detokenizer.decode_sequence_inplace(
ERROR 08-06 23:49:35 grpc_server.py:145]                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-06 23:49:35 grpc_server.py:145]   File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/transformers_utils/detokenizer.py", line 150, in decode_sequence_inplace
ERROR 08-06 23:49:35 grpc_server.py:145]     (_, new_text, _, _) = detokenize_incrementally(
ERROR 08-06 23:49:35 grpc_server.py:145]                           ^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-06 23:49:35 grpc_server.py:145]   File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/transformers_utils/detokenizer.py", line 287, in detokenize_incrementally
ERROR 08-06 23:49:35 grpc_server.py:145]     if new_token_id >= len(tokenizer):
ERROR 08-06 23:49:35 grpc_server.py:145]        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-06 23:49:35 grpc_server.py:145] TypeError: '>=' not supported between instances of 'NoneType  and 'int'

I traced this down to there being an entry in the logprobs of the sampler outputs that has None for the token id. The PR #6485, optimized speculative decoding by removing unnecessary operations on logprobs. This is configured with the disable_logprobs_during_spec_decoding flag, which is enabled by default (i.e. logprob computations are skipped by default). Instead of returning real logprobs, _create_dummy_logprob_lists creates dummy outputs with None entries for the top_k token ids and logprobs which leads to the error above in later processing. The change in this PR is to ignore these None entries in create_sequence_group_output.

FIX #6967

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link

github-actions bot commented Aug 6, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@tjohnson31415 tjohnson31415 changed the title [Bugfix] handle None entries in create_sequence_group_output [Bugfix] handle None entries in topk args in create_sequence_group_output Aug 7, 2024
@tjohnson31415 tjohnson31415 changed the title [Bugfix] handle None entries in topk args in create_sequence_group_output [Bugfix] spec decode handle None entries in topk args in create_sequence_group_output Aug 13, 2024
@tjohnson31415 tjohnson31415 force-pushed the spec-decoding-logprobs branch from e139f24 to 7d267e0 Compare August 20, 2024 19:37
@tjohnson31415 tjohnson31415 force-pushed the spec-decoding-logprobs branch from 7d267e0 to dcba1da Compare August 20, 2024 19:51
@tjohnson31415
Copy link
Contributor Author

tjohnson31415 commented Aug 20, 2024

Coming back to this PR:
I had this in draft while I investigated how to enable logprobs with spec decoding. It turns out I was getting confused by --disable-logprobs-during-spec-decoding not working as expected, i.e. --disable-logprobs-during-spec-decoding false would still result in the configuration being set to True and logprobs being disabled. The parsing of the flag's value has now been fixed with #7665.

i have also added a test to check the behavior with logprobs for spec decoding disabled, which would have caught this bug.

@tjohnson31415 tjohnson31415 marked this pull request as ready for review August 20, 2024 20:06
Copy link
Member

@njhill njhill left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@njhill njhill added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 22, 2024
@njhill njhill merged commit cc0eaf1 into vllm-project:main Aug 22, 2024
46 checks passed
@tjohnson31415 tjohnson31415 deleted the spec-decoding-logprobs branch August 22, 2024 16:15
omrishiv pushed a commit to omrishiv/vllm that referenced this pull request Aug 26, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Bug]: speculative decoding doesn't work with online mode
2 participants