
[SpecDec] Remove Batch Expansion (2/3) #9298

Merged
LiuXiaoxuanPKU merged 4 commits into vllm-project:main on Oct 12, 2024

Conversation

@LiuXiaoxuanPKU (Collaborator) commented on Oct 11, 2024

Follow-up of #8839.

We use flash_attn_varlen_func for the MQA scorer. Therefore, we can support different proposal lengths for different requests in the batch, which is essential for ngram and dynamic speculative decoding.
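For illustration, here is a hedged sketch of how scoring several proposed tokens per request can be expressed with flash_attn_varlen_func (the shapes, lengths, and packing step are invented for the example and require a GPU with flash-attn installed; this is not the PR's actual code):

```python
# Hedged sketch: packing a batch with different per-request query lengths and
# scoring it with flash_attn_varlen_func. All shapes/lengths are illustrative.
import torch
from flash_attn import flash_attn_varlen_func

num_heads, head_dim = 32, 128
query_lens = [4, 1, 4]        # e.g. k+1 tokens to score for two requests, 1 for another
kv_lens = [560, 557, 561]     # per-request context lengths

cu_seqlens_q = torch.tensor([0] + torch.cumsum(torch.tensor(query_lens), 0).tolist(),
                            dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.tensor([0] + torch.cumsum(torch.tensor(kv_lens), 0).tolist(),
                            dtype=torch.int32, device="cuda")

# Queries/keys/values are packed along the token dimension, with no padding.
q = torch.randn(sum(query_lens), num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(sum(kv_lens), num_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn_like(k)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=max(query_lens), max_seqlen_k=max(kv_lens),
    causal=True,
)
```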

The following are some preliminary benchmark numbers for the MQA scorer compared with batch expansion, with and without CUDA graph.

[Screenshot: benchmark results, updated 2024-10-21]


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@LiuXiaoxuanPKU requested review from comaniac and njhill and removed the request for njhill on October 11, 2024, 21:44
@comaniac (Collaborator) left a comment


LGTM. What happens if we enable CUDA graph in this PR?

@LiuXiaoxuanPKU added the ready (ONLY add when PR is ready to merge/full CI is needed) label on Oct 12, 2024
@LiuXiaoxuanPKU (Collaborator, Author)

> LGTM. What happens if we enable CUDA graph in this PR?

It will not work. During graph capture, we still go into the normal decoding path here. However, during spec decode scoring, we go into the varlen kernel branch here. Therefore, the captured graph cannot be used during the scoring phase.

@comaniac (Collaborator)

Yeah, I know. I mean, what will users see in this case? We should provide a proper error message.

@LiuXiaoxuanPKU (Collaborator, Author)

> Yeah, I know. I mean, what will users see in this case? We should provide a proper error message.

It will fall back to batch expansion if CUDA graph is enabled, as shown here.
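A minimal sketch of this fallback, with placeholder names (the real selection logic lives in vLLM's spec-decode worker and may look different):

```python
# Placeholder sketch of the fallback described above; not vLLM's actual code.
def choose_scorer(cuda_graph_enabled: bool, batch_expansion_scorer, mqa_scorer):
    # With CUDA graphs enabled, keep the batch-expansion scorer so the
    # captured graphs remain usable during scoring.
    if cuda_graph_enabled:
        return batch_expansion_scorer
    # With eager execution (--enforce-eager), the MQA scorer can use the
    # varlen kernel directly.
    return mqa_scorer
```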

@LiuXiaoxuanPKU enabled auto-merge (squash) on October 12, 2024, 03:05
@LiuXiaoxuanPKU merged commit 89feb4c into vllm-project:main on Oct 12, 2024
58 checks passed
@wooyeonlee0 (Contributor) commented on Oct 21, 2024

Great! Thanks for your work @LiuXiaoxuanPKU
I've been really looking forward to this feature :)
After testing the MQA feature, I got unexpected results that were somewhat different from yours.
In my experimental setting, the MQA scorer is not faster than batch expansion without CUDA graph.

Below I've shared the details of my experiment, which is quite similar to yours.
I hope sharing this helps you to build a complete MQA scorer :)

=========
GPU: 1 H100 (tp=1)
Model: meta-llama/Llama-2-7b-chat-hf (draft: JackFram/llama-68m)
input_len: 550
output_len: 150
num_speculative_tokens: 3
temperature: 0
script: benchmark_latency.py
vllm version: v0.6.3

scoring time

| batch_size | batch expansion (w/ CUDA graph) | batch expansion w/o CUDA graph | MQA scorer w/o CUDA graph |
| --- | --- | --- | --- |
| 1 | 7.20 ms | 9.88 ms | 10.19 ms |

Avg latency

| batch_size | batch expansion (w/ CUDA graph) | batch expansion w/o CUDA graph | MQA scorer w/o CUDA graph |
| --- | --- | --- | --- |
| 1 | 0.91 s | 1.34 s | 1.35 s |
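For reference, a hedged sketch of how a setting like this can be launched with vLLM's offline API (argument names as of v0.6.x and may differ in other versions; some versions also require enabling the v2 block manager for spec decode; the prompt is a placeholder):

```python
# Hedged reproduction sketch for the experiment above; not the exact benchmark script.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=3,
    tensor_parallel_size=1,
    enforce_eager=True,  # disable CUDA graphs so the MQA scorer path is exercised
)
params = SamplingParams(temperature=0, max_tokens=150)
outputs = llm.generate(["<a prompt of roughly 550 tokens>"], params)
```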

@LiuXiaoxuanPKU (Collaborator, Author)

@wooyeonlee0
Thanks for the attention! Yeah, there were some errors in the previous benchmark results; I just updated the numbers.

For the last column, MQA with CUDA graph, we implemented a quick version (without handling edge cases) to test the performance. Given these numbers, we are thinking that instead of adding CUDA graph support for MQA, which might introduce memory overhead because of the change in bucketing strategy, it might be easier to switch between the MQA scorer and the batch expansion scorer. For example, when batch_size < 8, use the batch expansion scorer with CUDA graph; when batch_size > 8, use the MQA scorer without CUDA graph. The switch should be relatively easy to implement without introducing overhead.
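Such a switch would essentially be a threshold check at scoring time; a placeholder sketch (the threshold of 8 comes from this comment, and the scorer objects are illustrative, not vLLM's real classes):

```python
# Placeholder sketch of the batch-size-based switch suggested above.
MQA_BATCH_SIZE_THRESHOLD = 8

def select_scorer(batch_size: int, batch_expansion_scorer, mqa_scorer):
    # Small batches: batch expansion with CUDA graph amortizes kernel-launch
    # overhead well, so it stays faster there.
    if batch_size < MQA_BATCH_SIZE_THRESHOLD:
        return batch_expansion_scorer
    # Larger batches: the MQA scorer avoids expanding the batch and wins even
    # without CUDA graph.
    return mqa_scorer
```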

Any thoughts/comments/discussion are appreciated!

@wooyeonlee0 (Contributor) commented on Oct 23, 2024

@LiuXiaoxuanPKU Thanks for the quick response! :)

> adding CUDA graph support for MQA, which might introduce memory overhead because of the change in bucketing strategy

Does the memory overhead come from needing multiple CUDA graphs, one for each different K value?
Then how about implementing another version of the MQA scorer that works with a fixed K value, since the current batch expansion also assumes a fixed K value? (I'm not sure whether it's possible.)
This fixed-K MQA scorer would support CUDA graph without such memory overhead, so it could completely replace batch expansion.

I think we can use the variable-K MQA scorer in this PR only when we want to use the dynamic speculative decoding feature.
What do you think?

@bettybaii commented on Oct 25, 2024

Thank you very much for your work @LiuXiaoxuanPKU 😊
When I disabled CUDA Graphs (by adding the --enforce-eager parameter), I observed a significant increase in the overhead of token proposals. Below are the details of my experiment :)

GPU: 1 A10 (tp=1)
Model: Meta-Llama-3-8B (draft: turboderp/Qwama-0.5B-Instruct)
input_len: 256
output_len: 10
batch_size: 16
num_speculative_tokens: 3
temperature: 0
script: benchmark_throughput.py
vllm version: v0.6.3.post2

When I did not add the --enforce-eager parameter (automatically using the batch expansion scorer):

INFO 10-25 02:19:22 spec_decode_worker.py:976] SpecDecodeWorker stage times: average_time_per_proposal_tok_ms=5.56 scoring_time_ms=43.28 verification_time_ms=1.52

When I added the --enforce-eager parameter (automatically using the MQA scorer):

INFO 10-25 02:26:19 spec_decode_worker.py:976] SpecDecodeWorker stage times: average_time_per_proposal_tok_ms=13.69 scoring_time_ms=41.28 verification_time_ms=1.49

It can be observed that with CUDA Graphs enabled, average_time_per_proposal_tok_ms is only 5.56 ms, whereas with CUDA Graphs disabled it increases to 13.69 ms :(

In addition, I conducted separate tests on the inference performance of the draft model with and without CUDA Graphs, as shown below:
with CUDA Graphs enabled: TPOP = 5.307 ms
with CUDA Graphs disabled: TPOP = 14.780 ms

Could you explain why enabling or disabling CUDA Graphs has such a significant impact on the draft model in my environment? Are there any solutions to resolve this issue?

Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
garg-amit pushed a commit to garg-amit/vllm that referenced this pull request Oct 28, 2024
@LiuXiaoxuanPKU (Collaborator, Author)

> When I disabled CUDA Graphs (by adding the --enforce-eager parameter), I observed a significant increase in the overhead of token proposals. [...] Could you explain why enabling or disabling CUDA Graphs has such a significant impact on the draft model in my environment? Are there any solutions to resolve this issue?

Yeah, it totally makes sense. When the draft model is small, we need CUDA graph to achieve good performance. CUDA graph support for the draft model should always be on.

@LiuXiaoxuanPKU (Collaborator, Author)

> Does the memory overhead come from needing multiple CUDA graphs, one for each different K value? Then how about implementing another version of the MQA scorer that works with a fixed K value [...]? I think we can use the variable-K MQA scorer in this PR only when we want to use the dynamic speculative decoding feature. What do you think?

The tricky part is that with ngram, it is very likely that requests within the same batch have different proposal lengths (0 or k): if there is a match for a given request, the proposal length will be k, otherwise 0. In this case, we cannot assume the same k for all requests in the batch.
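To illustrate (a made-up example, not code from the PR): with per-request proposal lengths of 0 or k, the packed query layout changes from batch to batch, so a single fixed-shape CUDA graph cannot cover it.

```python
# Toy example of an ngram batch where only some requests got a proposal.
# Assumes each request scores its proposed tokens plus one current/bonus token,
# which is an illustrative assumption, not a statement about vLLM internals.
from itertools import accumulate

k = 3
proposal_lens = [k, 0, k, 0]                  # ngram matched for requests 0 and 2 only
query_lens = [p + 1 for p in proposal_lens]   # [4, 1, 4, 1]
cu_seqlens_q = [0] + list(accumulate(query_lens))
print(cu_seqlens_q)                           # [0, 4, 5, 9, 10]
# These offsets differ whenever the match pattern differs, so the shapes fed to
# the varlen kernel are not static across steps.
```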

@wooyeonlee0 (Contributor)

@LiuXiaoxuanPKU Thanks for the detailed answer :) I didn't know that ngram has that kind of requirement!
But could we add an additional version of the MQA scorer with a fixed length to accelerate general cases?

By the way, I have one more question about generation outputs.
When using the MQA scorer, the generated output is slightly different from that of the BatchExpansion scorer.
Do you have any idea about this?

While exploring this, I found that batching can also introduce this kind of issue due to floating-point errors. (#966)
Does speculative decoding also have the same issue?

Notes

  • the result using BatchExpansion is also a bit different from that of incremental decoding.
  • the result of each method (MQA, BatchExpansion, Incremental) is consistent by itself.

sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024