Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] [Spec decode]: Combine chunked prefill with speculative decoding #9291

Merged
merged 13 commits into from
Nov 7, 2024

Conversation

NickLucche
Copy link
Contributor

@NickLucche NickLucche commented Oct 11, 2024

Hey, this PR implements #5016.

The main idea is to make use of the current Speculative Decoder workflow and integrate it with mixed prefill-decode batches.
In particular, we can run the batched prefills and decodes together through the scorer (with the usual prefill|decode layout supported by backend), while the proposer can sync its KV cache on prefills only.

image

Current attention kernel implementation still doesn't make full use of the prefill|decode, but once the MQA integration is finalized we can get an easy speedup by running the batch in a single forward.

Current implementation on main already is (to some extent) prefill aware, so I was able to re-use a good chunk of the logic and the changes aren't (purposely) drastic.
On the other hand, one could prioritize optimizations more and I am open to any suggestion on how to best implement the approach, even at the cost
of re-writing more parts and making the PR more invasive (ie breaking some of the interfaces to avoid duplication).

TODO:

  • benchmark on A/H100
  • expand test coverage with prefill chunking enabled
  • test with new mqa_scorer, current implementation was rebased from v0.6.2
  • fix speculative methods requiring return_hidden_states EDIT: on second thought, I believe this would be better addressed in a separate PR
  • disable_logprobs_during_spec_decoding compatibility

Update:

We add support for chunk prefill and spec decoding with the workflow depicted above, unless the proposer requires final hidden state from the target model (MLPSpeculator/Medusa): this is deferred to a second follow-up PR.

mqa_scorer is set to supersede BatchExpansion* thanks to the great work by @LiuXiaoxuanPKU, so we add support to that scorer directly in this PR!
Incidentally, this means enabling backend with flash_attn_varlen_func to take in any "mixed prefill-decode batch" in a single kernel call (so no more decoupled prefix-decode calls), which should also boost performance in "vanilla" chunked prefill scheduling policy (no spec).

Many thanks to @sroy745 for benchmarking the BatchExpansionTop1Scorer approach here (MQA to follow)!

Update 2:

After reviewing @sroy745 benchmarks, contrarily to expectations, fusing the two separate kernel call into a unified prefill+decode (single flash_attn_varlen_func call) did not yield improvements. I reverted the unifying kernel change, but I will keep the commit history here so we can come back to it and investigate some more on a separate optimization PR.

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@NickLucche NickLucche marked this pull request as draft October 11, 2024 17:04
Copy link
Collaborator

@sroy745 sroy745 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the pr. Left some comments. PTAL

vllm/spec_decode/spec_decode_worker.py Outdated Show resolved Hide resolved
vllm/spec_decode/spec_decode_worker.py Show resolved Hide resolved
vllm/spec_decode/batch_expansion.py Outdated Show resolved Hide resolved
vllm/spec_decode/batch_expansion.py Outdated Show resolved Hide resolved
vllm/config.py Show resolved Hide resolved
vllm/spec_decode/spec_decode_worker.py Show resolved Hide resolved
vllm/worker/model_runner.py Outdated Show resolved Hide resolved
@NickLucche NickLucche force-pushed the chunk-spec-decoding-rebase branch from 49b03ab to 8b88b8a Compare October 14, 2024 10:41
@arashsadrieh
Copy link

arashsadrieh commented Oct 15, 2024

@NickLucche Thanks for the great work and understand that is WIP, just small note while you are working on this piece

We tried this PR with tensor parallelism and we found that it throughs the following exception when we activate tensor parallelism:

python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8083 --model /8b/  --speculative_model /1b/  --served-model-name SpeculativeLLM --tensor-parallel-size 4  --max-model-len 34336  --max-num-seqs 128  --enable-prefix-caching  --disable-log-requests --use-v2-block-manager --seed 42 --num_speculative_tokens 5  --spec-decoding-acceptance-method typical_acceptance_sampler  --enable_chunked_prefill

Here is the exception:

Exception in worker VllmWorkerProcess while processing method start_worker_execution_loop: 'num_seq_groups', Traceback (most recent call last):
   File "/home/ec2-user/tengfei_workspace/vllm/vllm/executor/multiproc_worker_utils.py", line 224, in _run_worker_process
     output = executor(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^
   File "/opt/conda/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
     return func(*args, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^
   File "/home/ec2-user/tengfei_workspace/vllm/vllm/spec_decode/spec_decode_worker.py", line 459, in start_worker_execution_loop
     while self._run_non_driver_rank():
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
   File "/home/ec2-user/tengfei_workspace/vllm/vllm/spec_decode/spec_decode_worker.py", line 649, in _run_non_driver_rank
     self.proposer_worker.execute_model()
   File "/home/ec2-user/tengfei_workspace/vllm/vllm/worker/worker_base.py", line 308, in execute_model
     inputs = self.prepare_input(execute_model_req)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   File "/home/ec2-user/tengfei_workspace/vllm/vllm/worker/worker_base.py", line 298, in prepare_input
     return self._get_worker_input_from_broadcast()
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   File "/home/ec2-user/tengfei_workspace/vllm/vllm/worker/worker_base.py", line 240, in _get_worker_input_from_broadcast
     worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   File "/home/ec2-user/tengfei_workspace/vllm/vllm/worker/worker_base.py", line 151, in from_broadcasted_tensor_dict
     num_seq_groups=tensor_dict.pop("num_seq_groups"),
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 KeyError: 'num_seq_groups'

The following command works normally

python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8083 --model /home/ec2-user/tengfei_workspace/output/8b-aio-20240923-3/merged/ --speculative_model /home/ec2-user/tengfei_workspace/output/1b-aio-20240923-3/merged/ --served-model-name SpeculativeLLM --tensor-parallel-size 1 --max-model-len 34336 --max-num-seqs 128 --enable-prefix-caching --disable-log-requests --use-v2-block-manager --seed 42 --num_speculative_tokens 5  --spec-decoding-acceptance-method typical_acceptance_sampler --enable_chunked_prefill --tensor-parallel-size 1

Thanks again and appreciate your work/ VLLM community

@NickLucche
Copy link
Contributor Author

NickLucche commented Oct 15, 2024

Thanks for testing that, will look right into it!
Might actually be related to prefix_caching, which I haven't taken into account yet (I know there's been some recent work on that too).

@NickLucche
Copy link
Contributor Author

Update on mqa_scorer integration: enable_chunked_prefill with changes in this PR appears to work fine with the flash_attn kernel prior to the optimized one introduced here #9298 (so flash_attn_with_kvcache instead of flash_attn_varlen_func). I will sync with @LiuXiaoxuanPKU on this.

@NickLucche NickLucche marked this pull request as ready for review October 17, 2024 15:45
vllm/config.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@sroy745 sroy745 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the pr. Left a few comments. PTAL.

vllm/attention/backends/flash_attn.py Outdated Show resolved Hide resolved
vllm/config.py Outdated Show resolved Hide resolved
vllm/config.py Outdated Show resolved Hide resolved
vllm/spec_decode/spec_decode_worker.py Show resolved Hide resolved
vllm/spec_decode/spec_decode_worker.py Show resolved Hide resolved
vllm/spec_decode/spec_decode_worker.py Outdated Show resolved Hide resolved
vllm/spec_decode/spec_decode_worker.py Outdated Show resolved Hide resolved
tests/spec_decode/test_spec_decode_worker.py Outdated Show resolved Hide resolved
vllm/spec_decode/mqa_scorer.py Show resolved Hide resolved
tests/spec_decode/test_spec_decode_worker.py Show resolved Hide resolved
@NickLucche NickLucche force-pushed the chunk-spec-decoding-rebase branch from 0819d12 to 3e5b882 Compare October 22, 2024 09:15
Copy link
Collaborator

@sroy745 sroy745 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the pr! One comment about leaving out the unified kernel changes in this pr. Please check with @LiuXiaoxuanPKU and @comaniac on this. Otherwise LGTM.

vllm/attention/backends/flash_attn.py Outdated Show resolved Hide resolved
@NickLucche
Copy link
Contributor Author

Thanks for reviewing this!

Copy link
Collaborator

@sroy745 sroy745 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the pr!! LGTM.

vllm/attention/backends/flash_attn.py Outdated Show resolved Hide resolved
vllm/attention/backends/flash_attn.py Show resolved Hide resolved
@comaniac comaniac added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 30, 2024
Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Good job! Only nits.

vllm/config.py Outdated Show resolved Hide resolved
tests/utils.py Outdated Show resolved Hide resolved
Copy link

mergify bot commented Oct 30, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. @NickLucche please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Signed-off-by: NickLucche <[email protected]>
@NickLucche NickLucche force-pushed the chunk-spec-decoding-rebase branch from cd9bd2a to ca2691e Compare October 31, 2024 17:49
@mergify mergify bot removed the needs-rebase label Oct 31, 2024
@NickLucche
Copy link
Contributor Author

Mmm apologies for the automatic call to review on so many people, had to sign commits and force push

@sroy745
Copy link
Collaborator

sroy745 commented Nov 1, 2024

@NickLucche I think you need to remove test_spec_decode_xfail_chunked_prefill from spec_decode/e2e/test_compatibility.py since its no longer applicable. Could you also please sync your branch to the head. It seems like some of the failures e.g. in buildkite/ci-aws/pr/decoder-only-multi-modal-models-test might already be fixed in head.

@NickLucche NickLucche force-pushed the chunk-spec-decoding-rebase branch from 756d33f to fb66563 Compare November 4, 2024 17:47
Copy link
Member

@njhill njhill left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @NickLucche for the awesome work, and to @sroy745 @LiuXiaoxuanPKU @comaniac for the reviews

@njhill njhill merged commit 9d43afc into vllm-project:main Nov 7, 2024
57 checks passed
@andoorve
Copy link
Collaborator

andoorve commented Nov 7, 2024

Hi @NickLucche, thanks for the PR!

I tried with TP on the latest main. It seems like I get the same error as @arashsadrieh still. Is this expected to work?

KeyError: 'num_seq_groups'

@NickLucche
Copy link
Contributor Author

NickLucche commented Nov 7, 2024

Hey @andoorve, yeah tp for the target model should be working, iirc even @sroy745's benchmarks ran with tp=4. Unfortunately I do not have a way to test master right now as I am away :/

@sroy745
Copy link
Collaborator

sroy745 commented Nov 7, 2024

Hi @andoorve / @arashsadrieh
I am able to run with this pr with the following command

python3 -m vllm.entrypoints.openai.api_server --model "meta-llama/Meta-Llama-3-70B-Instruct" --tensor-parallel-size 4 --disable-log-requests --enable-chunked-prefill --max_num_batched_tokens 2048 --speculative_model turboderp/Qwama-0.5B-Instruct --num_speculative_tokens 1 --speculative_draft_tensor_parallel_size 1 --disable-custom-all-reduce --swap_space 16 --speculative_disable_mqa_scorer

What is the command you are using?

One difference I think is that in our evals we ran with the speculative model running with tp=1 and the target model running with tp=4. Can you try and see if that works for you?

@andoorve
Copy link
Collaborator

andoorve commented Nov 7, 2024

Hey @NickLucche @sroy745, this is what I'm using. I think this is the difference, as I'm running with TP > 1 on the draft model as well. Unfortunately the Llama 8B draft model that I want to use is relatively large for TP=1.

vllm serve meta-llama/Llama-3.1-405B-Instruct-FP8 --tensor-parallel-size 8 --max-num-seqs 32  --block-size 32  --speculative-model meta-llama/Llama-3.1-8B-Instruct  --num-speculative-tokens 8 --gpu-memory-utilization  0.98 --use-v2-block-manager --distributed-executor-backend ray --enable-chunked-prefill --max-num-batched-tokens 4096 --max-model-len 32768

@sroy745
Copy link
Collaborator

sroy745 commented Nov 8, 2024

Hey @NickLucche @sroy745, this is what I'm using. I think this is the difference, as I'm running with TP > 1 on the draft model as well. Unfortunately the Llama 8B draft model that I want to use is relatively large for TP=1.

vllm serve meta-llama/Llama-3.1-405B-Instruct-FP8 --tensor-parallel-size 8 --max-num-seqs 32  --block-size 32  --speculative-model meta-llama/Llama-3.1-8B-Instruct  --num-speculative-tokens 8 --gpu-memory-utilization  0.98 --use-v2-block-manager --distributed-executor-backend ray --enable-chunked-prefill --max-num-batched-tokens 4096 --max-model-len 32768

I will add a check to verify that sd + chunked-prefill is enabled for tp=1 draft model and then continue with the investigation. It is not breaking any existing cases so will add the check and debug.

Isotr0py pushed a commit to Isotr0py/vllm that referenced this pull request Nov 8, 2024
JC1DA pushed a commit to JC1DA/vllm that referenced this pull request Nov 11, 2024
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024
tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ci/build documentation Improvements or additions to documentation frontend ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants