[Kernel]Generalize Speculative decode from Cuda #10094

Closed (10 commits)

Conversation

@xuechendi (Contributor) commented Nov 6, 2024

This PR mainly targets removing the hard dependency on CUDA in speculative decoding.

Done:

  1. Removed the hard dependency; the worker / model runner is now selected based on current_platform
  2. Per mgoin's suggestion, enabled CPU support for speculative decoding

Based on discussion with @comaniac and @youkaichao, I provided a second solution that avoids the dynamic WorkerCls => #10587


Settings:

  • draft model
    llm = LLM(
        model="facebook/opt-1.3b",
        speculative_model="facebook/opt-125m",
        num_speculative_tokens=5,
        use_v2_block_manager=True,
    )
  • medusa
    llm = LLM(
        model="JackFram/llama-68m",
        speculative_model="abhigoyal/vllm-medusa-llama-68m-random",
        num_speculative_tokens=4,
        use_v2_block_manager=True,
    )
  • eagle
    llm = LLM(
        model="JackFram/llama-68m",
        speculative_model="abhigoyal/vllm-eagle-llama-68m-random",
        num_speculative_tokens=5,
        use_v2_block_manager=True
    )
  • mlp
    llm = LLM(
        model="JackFram/llama-160m",
        speculative_model="ibm-fms/llama-160m-accelerator",
        num_speculative_tokens=3,
        use_v2_block_manager=True
    )

  • ngram
    llm = LLM(
        model="facebook/opt-350m",
        speculative_model="[ngram]",
        num_speculative_tokens=5,
        ngram_prompt_lookup_max=3,
        use_v2_block_manager=True,
    )
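
For reference, any of the configurations above can be exercised with a short generation call. The prompt and sampling parameters below are illustrative only and are not part of the PR:

    from vllm import LLM, SamplingParams

    # Draft-model speculative decoding, using the first configuration listed above.
    llm = LLM(
        model="facebook/opt-1.3b",
        speculative_model="facebook/opt-125m",
        num_speculative_tokens=5,
        use_v2_block_manager=True,
    )

    # Greedy sampling makes it easy to compare against a non-speculative run.
    outputs = llm.generate(
        ["The future of AI is"],
        SamplingParams(temperature=0.0, max_tokens=64),
    )
    print(outputs[0].outputs[0].text)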

github-actions bot commented Nov 6, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which starts a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@xuechendi (Contributor, Author) commented Nov 6, 2024

Hi @LiuXiaoxuanPKU, could you take a look at this PR?
I want to remove the hard CUDA dependency in speculative decoding.

@xuechendi xuechendi force-pushed the spec_decode_detach_hw branch 4 times, most recently from 16a98e1 to 23037b4 on November 6, 2024 at 23:19
@xuechendi (Contributor, Author) commented Nov 7, 2024

Hi, @simon-mo, will you check on this PR?

mergify bot commented Nov 7, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xuechendi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 7, 2024
@xuechendi xuechendi force-pushed the spec_decode_detach_hw branch from 615ea18 to cdd0471 on November 7, 2024 at 16:51
@mergify mergify bot removed the needs-rebase label Nov 7, 2024
@xuechendi xuechendi closed this Nov 7, 2024
@xuechendi xuechendi reopened this Nov 7, 2024
@xuechendi xuechendi force-pushed the spec_decode_detach_hw branch from 6247f29 to 1ea5684 on November 7, 2024 at 20:21
@xuechendi (Contributor, Author) commented Nov 7, 2024

@WoosukKwon , will you take a look at this PR?

@mgoin (Member) commented Nov 7, 2024

Although it may not be practical due to the lack of compute intensity, it would be helpful for testing the generalization to have a CPU implementation, so that the non-CUDA path can be exercised more easily.

@xuechendi (Contributor, Author)

@mgoin, CPU support for spec decode has been added. Please help review it.

@xuechendi xuechendi changed the title from "Generalize Speculative decode from Cuda" to "[Kernel]Generalize Speculative decode from Cuda" on Nov 8, 2024
@xuechendi (Contributor, Author)

@cadedaniel, could you take a look at this PR? I would like to remove spec decode's hard dependencies on CUDA so that it can be applied to other hardware.

@xuechendi xuechendi force-pushed the spec_decode_detach_hw branch from 4337679 to 77ac59a on November 8, 2024 at 16:35

mergify bot commented Nov 11, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xuechendi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 11, 2024
@cadedaniel (Collaborator)

@cadedaniel, could you take a look at this PR? I would like to remove spec decode's hard dependencies on CUDA so that it can be applied to other hardware.

Can you share the performance improvement on AMD hardware? Cc @LiuXiaoxuanPKU @comaniac

@xuechendi (Contributor, Author) commented Nov 12, 2024

@cadedaniel, could you take a look at this PR? I would like to remove spec decode's hard dependencies on CUDA so that it can be applied to other hardware.

Can you share the performance improvement on AMD hardware? Cc @LiuXiaoxuanPKU @comaniac

@cadedaniel, thanks for reviewing this PR. My aim here is first to make it possible to run spec decode on hardware other than GPUs.
Performance-wise, I believe different hardware may need special treatment to reach optimal performance (maybe we can do that in another PR?). Adding CPU support here is only to show that all hard dependencies on GPU have been cleaned up, so this PR may not be the best implementation for CPU.

FYI, we have another proposal for a heterogeneous setup that runs the draft model on CPU and the target model on GPU. We can discuss that later; it may be a better use case for running spec decode on CPU.

@xuechendi (Contributor, Author)

Hi @njhill, I just learned you are one of the owners of spec decode. Could you help review this PR?

@xuechendi xuechendi force-pushed the spec_decode_detach_hw branch from 77ac59a to 9a3bd16 on November 15, 2024 at 22:11
@mergify mergify bot removed the needs-rebase label Nov 15, 2024

mergify bot commented Nov 20, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xuechendi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 20, 2024
Comment on lines 14 to 27
if current_platform.is_neuron():
    from vllm.worker.neuron_worker import NeuronWorker as WorkerCls
elif current_platform.is_hpu():
    from vllm.worker.hpu_worker import HPUWorker as WorkerCls
elif current_platform.is_openvino():
    from vllm.worker.openvino_worker import OpenVINOWorker as WorkerCls
elif current_platform.is_cpu():
    from vllm.worker.cpu_worker import CPUWorker as WorkerCls
elif current_platform.is_tpu():
    from vllm.worker.tpu_worker import TPUWorker as WorkerCls
elif current_platform.is_xpu():
    from vllm.worker.xpu_worker import XPUWorker as WorkerCls
else:
    from vllm.worker.worker import Worker as WorkerCls
Collaborator

This is not a clean and concise way to support non-CUDA workers, so apparently you'll need some design work.

@xuechendi (Contributor, Author) commented Nov 20, 2024

@comaniac, I could put a worker_selector.py in either the worker folder or the spec_decode folder. The reason I didn't is that when I discussed this with @LiuXiaoxuanPKU, she preferred to keep this PR as simple as possible.

I would like your opinion here. The idea is that I can extract the code above into a new file, and in spec_decode_worker, medusa_worker, etc., simply do "from vllm.worker.selector import WorkerCls".
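
A minimal sketch of that extraction, assuming a hypothetical vllm/worker/selector.py; only a few platforms are shown, and the module path is illustrative rather than something this PR adds:

    # vllm/worker/selector.py (hypothetical): resolve the platform's Worker once,
    # so spec-decode workers can simply `from vllm.worker.selector import WorkerCls`.
    from vllm.platforms import current_platform

    if current_platform.is_cpu():
        from vllm.worker.cpu_worker import CPUWorker as WorkerCls
    elif current_platform.is_hpu():
        from vllm.worker.hpu_worker import HPUWorker as WorkerCls
    else:
        from vllm.worker.worker import Worker as WorkerCls  # CUDA/ROCm default

Spec-decode classes such as MedusaWorker or MultiStepWorker would then derive from this WorkerCls, which is the dynamic-base-class pattern questioned in the following comments.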

Collaborator

The problem is that I don't think the current PR is simple, given that this logic is tedious and duplicated everywhere. I'm also not sure it is reliable to derive classes from a dynamic variable (i.e. current_platform) in a distributed environment.

Contributor Author

Thanks @comaniac, do you mean support for heterogeneous platforms in the spec decode path?
Yeah, I totally agree that the current code is tedious. Do you think extracting the worker selector into a single file would simplify it enough, or do you have another suggestion?

I am totally open to discussing the design.

Collaborator

I don't mean supporting heterogeneous platforms. I just feel that class MedusaWorker(NonLLMProposerWorkerBase, WorkerCls) deriving from a dynamic WorkerCls is not trivial, and I'm not sure it is safe and reliable.

Contributor Author

@comaniac, I see. Alternatively, I could add all the necessary APIs to worker_base.py and make medusa_worker / multi_step_worker and the others derive from "WorkerBase" instead of "Worker". But that change would be tremendous, which is why I am not sure I should do it.

I tested the current approach of using a dynamic WorkerCls: it works on CUDA and CPU, and also works for HPU in my own dev environment, so I consider it a valid solution.
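
For illustration, a rough sketch of that alternative; NonLLMProposerWorkerBase is omitted and the method body is a placeholder, so this is not a drop-in change:

    # Hypothetical: derive spec-decode workers from the abstract WorkerBase instead
    # of a platform-specific Worker chosen at import time.
    from vllm.worker.worker_base import WorkerBase


    class MedusaWorker(WorkerBase):
        def init_device(self) -> None:
            # Any device-specific setup previously inherited from Worker would have
            # to be reimplemented or delegated here, hence the "tremendous" change.
            ...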

Contributor Author

@comaniac, I updated this PR; WorkerCls is now defined in "vllm/spec_decode/selector.py" instead of being spread all around. Please check whether this looks better.

Contributor Author

@comaniac, I verified the distributed case as well, using the test below:

CUDA_VISIBLE_DEVICES=0,1 pytest -v tests/spec_decode/e2e/test_integration_dist_tp2.py::test_draft_model_tp_lt_target_model_tp2

@mergify mergify bot removed the needs-rebase label Nov 20, 2024
Collaborator

base_cls_selector.py may be a better name for this.

Can we wrap the logic in an API? For example:

def get_worker_cls_by_platform():
    ...

In general this is still not the best practice, but I don't have a better solution atm.
cc @youkaichao
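
A sketch of the suggested API shape, reusing the platform checks from the diff above; the file name follows the comment and the body is illustrative:

    # vllm/spec_decode/base_cls_selector.py (name per the suggestion above)
    from vllm.platforms import current_platform


    def get_worker_cls_by_platform():
        """Return the platform-specific Worker class for spec decode to derive from."""
        if current_platform.is_cpu():
            from vllm.worker.cpu_worker import CPUWorker
            return CPUWorker
        if current_platform.is_hpu():
            from vllm.worker.hpu_worker import HPUWorker
            return HPUWorker
        from vllm.worker.worker import Worker  # CUDA/ROCm default
        return Worker

Call sites would then resolve the class at runtime, e.g. WorkerCls = get_worker_cls_by_platform(), instead of importing WorkerCls as a module-level side effect.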

Three resolved review threads on vllm/spec_decode/spec_decode_worker.py (outdated).
@@ -320,7 +348,7 @@ def init_device(self) -> None:
                "[Speculative Decoding] Use MQA scorer for scoring proposals.")

        self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
-                                device=self.device,
+                                device=self.device.type,
Collaborator

The argument is device, so you shouldn't pass the device type. You could take the device type inside scorer_cls instead, and then you wouldn't need to change this line.

Contributor Author

Hi @comaniac, the reason I changed that is that the device type is a str in the scorer_cls init, but a device object was being passed, so it failed the mypy check.

https://github.com/vllm-project/vllm/blob/main/vllm/spec_decode/interfaces.py#L78-L79
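
For context, a minimal illustration of the mismatch described above; the class is simplified from the linked interfaces.py, where device is annotated as str:

    import torch

    class SpeculativeScorer:  # simplified; the real class lives in vllm/spec_decode/interfaces.py
        def __init__(self, scorer_worker, device: str, vocab_size: int):
            self._scorer_worker = scorer_worker
            self._device = device

    dev = torch.device("cuda:0")
    # Passing the torch.device itself conflicts with the `device: str` annotation under mypy;
    # passing dev.type ("cuda") is what the call site was changed to.
    scorer = SpeculativeScorer(scorer_worker=None, device=dev.type, vocab_size=32000)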

Resolved review thread on vllm/spec_decode/ngram_worker.py (outdated).
        ModelInputForNeuron as ModelInputCls)
    from vllm.worker.neuron_model_runner import (  # noqa: F401
        NeuronModelRunner as ModelRunnerCls)
    from vllm.worker.neuron_worker import (  # noqa: F401
Member

Oh, I actually plan to add an argument like --worker-cls auto and let every platform select its own worker class. We should do that.
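
A rough sketch of that idea; the flag name comes from the comment above, and the resolution logic below is purely illustrative, not the mechanism that eventually landed:

    import importlib

    from vllm.platforms import current_platform


    def resolve_worker_cls(worker_cls: str = "auto"):
        """Illustrative resolver for a --worker-cls flag: 'auto' defers to the
        current platform, anything else is treated as 'module.ClassName'."""
        if worker_cls == "auto":
            if current_platform.is_cpu():
                worker_cls = "vllm.worker.cpu_worker.CPUWorker"
            else:
                worker_cls = "vllm.worker.worker.Worker"
        module_name, class_name = worker_cls.rsplit(".", 1)
        return getattr(importlib.import_module(module_name), class_name)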

Contributor Author

@youkaichao, is there something I can refer to? Or does this file work as-is? Currently I put it under the spec_decode folder; it may also make sense to put it under the worker folder.

@xuechendi (Contributor, Author)

@comaniac, I resolved most of your comments and left two TODOs:

  1. Change 'device.type' back to 'device'. The reason I changed it to 'device.type' was a type fix captured during the mypy check: the SpeculativeScorer init function requires the device as a 'str', and changing it back to 'device' fails the mypy check.
  2. Define get_worker_cls_by_platform() in selector.py => I saw Kaichao said he has some plan for that; I'll check with him, so I left selector.py unchanged for the moment.

@comaniac (Collaborator)

  2. Define get_worker_cls_by_platform() in selector.py => I saw Kaichao said he has some plan for that; I'll check with him, so I left selector.py unchanged for the moment.

#10555 should fix this.

mergify bot commented Nov 22, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xuechendi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 22, 2024
@xuechendi (Contributor, Author)

Thanks @comaniac, I created a new PR that uses WorkerWrapperBase instead of a dynamic WorkerCls. Please see => #10587

@xuechendi xuechendi closed this Nov 25, 2024
@xuechendi xuechendi deleted the spec_decode_detach_hw branch December 19, 2024 21:49