
[Speculative Decoding] Fixing hidden states handling in batch expansion #7508

Merged
merged 6 commits into from
Aug 20, 2024

Conversation

abhigoyal1997
Contributor

This PR fixes the handling of hidden_states in batch expansion.

Fix #7505


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which consists of a small, essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of the default ones by unblocking the steps in your fast-check build in the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI, as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@abhigoyal1997
Contributor Author

/ready

@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 14, 2024
@cadedaniel
Collaborator

  • can you add a test?
  • what exactly does this fix? were the hidden states applied incorrectly?

@abhigoyal1997
Contributor Author

abhigoyal1997 commented Aug 15, 2024

  • can you add a test?

Would an E2E test be sufficient? I can add the example I mentioned in the issue where we see the error.

  • what exactly does this fix? were the hidden states applied incorrectly?

There was an error in a specific scenario. Consider sequences for which speculative decoding is disabled (e.g., when num_tokens + spec_tokens > max_model_len). For such sequences the proposal length is set to 0, and they are handled separately from the other sequences. When contracting the batch (in BatchExpansionTop1Scorer._contract_batch), the spec and non_spec sequences are merged into shape (batch_size, spec_length + 1). While this was done correctly for the other tensors, hidden_states were left unchanged and only reshaped afterwards:

hidden_states = hidden_states.reshape(-1, max_proposal_len + 1,

This reshape works fine if all sequences have proposal length > 0, but it fails as soon as spec decode is disabled for any sequence. This PR changes batch expansion (mainly BatchExpansionTop1Scorer._contract_batch) so that hidden_states are rearranged the same way as the other tensors (tokens, probs, and logprobs). This ensures hidden_states are correctly updated in all cases.
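To make the failure mode concrete, here is a toy illustration (plain Python, hypothetical sizes) of why a bare reshape breaks when one sequence has its proposal length forced to 0, and how a per-sequence scatter, like the one used for tokens/probs/logprobs, avoids the problem:

```python
# Toy illustration: three sequences, the third has spec decode disabled.
proposal_lens = [2, 2, 0]
max_proposal_len = max(proposal_lens)            # 2

# After batch expansion, each sequence contributes (proposal_len + 1)
# scored rows in the flat hidden_states tensor.
total_rows = sum(l + 1 for l in proposal_lens)   # 3 + 3 + 1 = 7

# The old reshape(-1, max_proposal_len + 1, ...) requires total_rows to be
# divisible by (max_proposal_len + 1); here 7 % 3 == 1, so it would fail.
assert total_rows % (max_proposal_len + 1) != 0

# The fix instead places each sequence's rows into a padded
# (batch_size, max_proposal_len + 1) layout, like the other tensors.
batch_size = len(proposal_lens)
padded = [[None] * (max_proposal_len + 1) for _ in range(batch_size)]
row = 0
for i, l in enumerate(proposal_lens):
    for j in range(l + 1):
        padded[i][j] = row   # index into the flat hidden_states rows
        row += 1
print(padded)  # [[0, 1, 2], [3, 4, 5], [6, None, None]]
```

The padded slots (None here) correspond to positions that are ignored downstream for sequences with shorter proposals.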

@cadedaniel
Collaborator

Thanks for the writeup.

Would an E2E test be sufficient? I can add the example I mentioned in the issue where we see the error.

An E2E test is good. I think the issue here is that we should validate an accuracy level, not just that it doesn't crash. For this you can build on PR #6454, which adds an assertion that the draft acceptance rate is 100%.

@abhigoyal1997
Contributor Author

abhigoyal1997 commented Aug 16, 2024

An E2E test is good. I think the issue here is that we should validate an accuracy level, not just that it doesn't crash. For this you can build on this PR #6454 which adds an assertion that draft acceptance rate is 100%.

Makes sense, but any draft model that uses hidden states can't be identical to the target model, so I don't think we can get a draft acceptance rate of 100% in that case. What do you suggest then?

@cadedaniel
Collaborator

You can run it for the test prompts and record the accuracy, then assert it doesn't go below that fixed value minus some epsilon.
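The suggested regression check can be sketched as follows. This is a minimal, hypothetical illustration: the baseline value and epsilon are assumptions, and check_acceptance_rate is not a real vLLM helper.

```python
# Hypothetical regression check: record the acceptance rate once from a
# known-good run, then assert later runs don't drop below it minus epsilon.
RECORDED_ACCEPTANCE_RATE = 0.48  # assumed baseline from a known-good run
EPSILON = 0.02                   # assumed tolerance for run-to-run noise

def check_acceptance_rate(observed: float) -> None:
    """Fail if the observed draft acceptance rate regressed."""
    threshold = RECORDED_ACCEPTANCE_RATE - EPSILON
    assert observed >= threshold, (
        f"acceptance rate regressed: {observed:.2f} < {threshold:.2f}")

check_acceptance_rate(0.48)  # baseline itself passes
check_acceptance_rate(0.47)  # small fluctuation within epsilon also passes
```

A fixed-baseline-minus-epsilon assertion catches correctness regressions (like the hidden-states bug here) without requiring the draft model to match the target exactly.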

@abhigoyal1997
Contributor Author

abhigoyal1997 commented Aug 17, 2024

@cadedaniel I've added the test along with other MLPSpeculator tests.

@njhill njhill self-requested a review August 17, 2024 18:13
Member

@njhill njhill left a comment


Thanks @abhigoyal1997!

@njhill
Member

njhill commented Aug 19, 2024

@cadedaniel you good with this being merged?

Collaborator

@cadedaniel cadedaniel left a comment


yep, good to merge! thanks for the great testing here

temperature=0.0,
seeded=True,
force_output_len=True,
expected_acceptance_rate=0.48)
Collaborator


@njhill do you know what value we should expect here?

Collaborator


(like, is this ballpark correct)

Member


@cadedaniel the value looks reasonable but I’m actually not sure what’s expected, will try to find out when I get a chance.

@cadedaniel cadedaniel merged commit 312f761 into vllm-project:main Aug 20, 2024
46 checks passed
zifeitong pushed a commit to zifeitong/vllm that referenced this pull request Aug 20, 2024
fialhocoelho pushed a commit to opendatahub-io/vllm that referenced this pull request Aug 22, 2024
omrishiv pushed a commit to omrishiv/vllm that referenced this pull request Aug 26, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
Labels
ready ONLY add when PR is ready to merge/full CI is needed

Successfully merging this pull request may close these issues.

[Bug]: Error in how HiddenStates are handled for speculative decoding
3 participants