[Kernel] Generalize Speculative decode from Cuda #10094
vllm/spec_decode/cpu_draft_model_runner.py (new file)

@@ -0,0 +1,48 @@
from typing import List, Optional

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import IntermediateTensors
from vllm.worker.cpu_model_runner import CPUModelRunner as ModelRunnerBaseCls
from vllm.worker.cpu_model_runner import ModelInputForCPUWithSamplingMetadata

logger = init_logger(__name__)


class CPUTP1DraftModelRunner(ModelRunnerBaseCls):
    """Specialized model runner for speculative decoding draft model.
    Since the draft model always executes k forward passes consecutively to
    generate k speculative tokens in a single speculative decoding step,
    we could get rid of most CPU-GPU synchronization and data transfer
    overheads by keeping model input and output tensors on GPU all the time.
    TODOs:
    1. Support TP > 1 (this requires some designs because we do not expect
       any broadcasting inside execute_model).
    """

    def __init__(self, *args, **kwargs):
        if kwargs.get("return_hidden_states"):
            raise ValueError(
                "return_hidden_states is not supported for TP1DraftModelRunner."
            )
        super().__init__(*args, **kwargs)
        self.indices_of_seq_with_bonus_tokens = None

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForCPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        previous_hidden_states: Optional[torch.Tensor] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        return super().execute_model(
            model_input=model_input,
            kv_caches=kv_caches,
            previous_hidden_states=previous_hidden_states,
            intermediate_tensors=intermediate_tensors,
            num_steps=num_steps,
        )
vllm/spec_decode/spec_decode_worker.py
@@ -14,12 +14,18 @@
     SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
 from vllm.model_executor.layers.typical_acceptance_sampler import (
     TypicalAcceptanceSampler)
+from vllm.platforms import current_platform
 from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                            CompletionSequenceGroupOutput, ExecuteModelRequest,
                            HiddenStates, SequenceGroupMetadata,
                            get_all_seq_ids_and_request_ids)
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
-from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+
+if current_platform.is_cuda_alike():
+    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+elif current_platform.is_cpu():
+    from vllm.spec_decode.cpu_draft_model_runner import CPUTP1DraftModelRunner
+
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.medusa_worker import MedusaWorker
@@ -36,9 +42,23 @@
                                     get_all_num_logprobs,
                                     get_sampled_token_logprobs, nvtx_range,
                                     split_batch_by_proposal_len)
-from vllm.worker.worker import Worker
 from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
+
+if current_platform.is_neuron():
+    from vllm.worker.neuron_worker import NeuronWorker as WorkerCls
+elif current_platform.is_hpu():
+    from vllm.worker.hpu_worker import HPUWorker as WorkerCls
+elif current_platform.is_openvino():
+    from vllm.worker.openvino_worker import OpenVINOWorker as WorkerCls
+elif current_platform.is_cpu():
+    from vllm.worker.cpu_worker import CPUWorker as WorkerCls
+elif current_platform.is_tpu():
+    from vllm.worker.tpu_worker import TPUWorker as WorkerCls
+elif current_platform.is_xpu():
+    from vllm.worker.xpu_worker import XPUWorker as WorkerCls
+else:
+    from vllm.worker.worker import Worker as WorkerCls

 logger = init_logger(__name__)
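The block above binds WorkerCls to a platform-specific worker class at import time; elsewhere in the PR, classes such as MedusaWorker then inherit from this dynamically chosen base, which is the point debated in the review thread further down. A minimal, self-contained illustration of that pattern (toy class names and a made-up platform flag, not vLLM code):

# Toy illustration of selecting a base class at import time and subclassing it.
class _ToyCUDAWorker:
    def device_type(self) -> str:
        return "cuda"


class _ToyCPUWorker:
    def device_type(self) -> str:
        return "cpu"


_ON_CPU = True  # stands in for a current_platform.is_cpu()-style check
ToyWorkerCls = _ToyCPUWorker if _ON_CPU else _ToyCUDAWorker


class ToySpecDecodeWorker(ToyWorkerCls):  # dynamic base class, as in the PR
    pass


print(ToySpecDecodeWorker().device_type())  # prints "cpu" in this sketch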
@@ -53,7 +73,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
     draft_worker_kwargs = kwargs.copy()

     kwargs["model_runner_cls"] = TargetModelRunner
-    target_worker = Worker(*args, **kwargs)
+    target_worker = WorkerCls(*args, **kwargs)
     # Set the disable_logprobs variable in the TargetModelRunner instance
     # as per its value specified in the SpeculativeConfig.
     target_worker.model_runner.disable_logprobs =\
@@ -125,7 +145,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
     @classmethod
     def create_worker(
         cls,
-        scorer_worker: Worker,
+        scorer_worker: WorkerCls,
         draft_worker_kwargs: Dict[str, Any],
         disable_mqa_scorer: bool,
         disable_by_batch_size: Optional[int],
@@ -158,8 +178,15 @@ def create_worker(
             proposer_worker = MedusaWorker(**draft_worker_kwargs)
         else:
             if draft_tp == 1:
-                draft_worker_kwargs[
-                    "model_runner_cls"] = TP1DraftModelRunner
+                if current_platform.is_cuda_alike():
+                    draft_worker_kwargs[
+                        "model_runner_cls"] = TP1DraftModelRunner
+                elif current_platform.is_cpu():
+                    draft_worker_kwargs[
+                        "model_runner_cls"] = CPUTP1DraftModelRunner
+                else:
+                    raise NotImplementedError(
+                        "current platform does not support EAGLE.")
             else:
                 if draft_model_config.hf_config.model_type == "eagle":
                     raise NotImplementedError(
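For context, draft_worker_kwargs["model_runner_cls"] is a class override handed to the draft worker's constructor. A simplified, hypothetical consumer of that override (all names here are invented for illustration, not vLLM's worker code) might look like:

# Hypothetical sketch of a worker honoring a model_runner_cls override.
class _ToyDefaultRunner:
    pass


class _ToyDraftWorker:
    def __init__(self, model_runner_cls=None, **kwargs):
        # Use the platform-appropriate runner if one was injected,
        # otherwise fall back to the default runner class.
        runner_cls = model_runner_cls or _ToyDefaultRunner
        self.model_runner = runner_cls()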
@@ -306,8 +333,9 @@ def init_device(self) -> None:
         self.scorer_worker.load_model()
         self.proposer_worker.load_model()

-        self._metrics.init_gpu_tensors(self.rank)
-        self.spec_decode_sampler.init_gpu_tensors(self.rank)
+        self._metrics.init_tensors(self.rank, device_type=self.device.type)
+        self.spec_decode_sampler.init_tensors(self.rank,
+                                              device_type=self.device.type)

         scorer_cls: Type[SpeculativeScorer]
         if self.disable_mqa_scorer:
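The rename from init_gpu_tensors to init_tensors is the device generalization at the heart of this hunk: the caller now passes the device type instead of assuming CUDA. A hedged sketch of what such an initializer might look like (the real signature and tensor names in vLLM's samplers and metrics may differ):

import torch


class _ToySpecDecodeSampler:
    # Hypothetical device-agnostic initializer: instead of assuming CUDA,
    # the caller passes the device type (e.g. "cuda", "cpu", "hpu").
    def init_tensors(self, rank: int, device_type: str = "cuda") -> None:
        if device_type == "cpu":
            device = torch.device("cpu")  # no per-rank device index on CPU
        else:
            device = torch.device(f"{device_type}:{rank}")
        self.num_accepted_tokens = torch.tensor(0,
                                                dtype=torch.long,
                                                device=device)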
@@ -320,7 +348,7 @@ def init_device(self) -> None:
                 "[Speculative Decoding] Use MQA scorer for scoring proposals.")

         self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
-                                 device=self.device,
+                                 device=self.device.type,
                                  vocab_size=self._vocab_size)

         self._configure_model_sampler_for_spec_decode()

comaniac: The argument is
xuechendi: Hi @comaniac, the reason I changed this is that the device argument is typed as a str in the scorer class's __init__, but it was being passed `device`, so it failed the mypy check: https://github.com/vllm-project/vllm/blob/main/vllm/spec_decode/interfaces.py#L78-L79
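For context on that exchange, the linked interfaces.py annotates the scorer's device parameter as a plain string, which is why self.device.type is passed here. A rough paraphrase of the shape of that constructor (this may not match the exact vLLM signature):

class SpeculativeScorerSketch:
    # Rough paraphrase of the scorer constructor referenced above: `device`
    # is a plain string such as "cuda" or "cpu", hence self.device.type.
    def __init__(self, scorer_worker, device: str, vocab_size: int):
        self._scorer_worker = scorer_worker
        self._device = device
        self._vocab_size = vocab_size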
@@ -1090,11 +1118,11 @@ def get_cache_block_size_bytes(self):
         raise NotImplementedError

     def start_profile(self):
-        if isinstance(self.scorer_worker, Worker):
+        if isinstance(self.scorer_worker, WorkerCls):
             self.scorer_worker.start_profile()

     def stop_profile(self):
-        if isinstance(self.scorer_worker, Worker):
+        if isinstance(self.scorer_worker, WorkerCls):
             self.scorer_worker.stop_profile()
comaniac: This is not a clean and concise way to support non-CUDA workers, so apparently you'll need some design work.

xuechendi: @comaniac, I could put a worker_selector.py in either the worker folder or the spec_decode folder. I didn't do that because when I discussed this with @LiuXiaoxuanPKU, she preferred to keep this PR as simple as possible. I'd like your opinion here. The idea is that I can extract the code above into a new file, and in spec_decode_worker, medusa_worker, etc. simply do "from vllm.worker.selector import WorkerCls".
comaniac: The problem is I don't think the current PR is simple, given that this logic is tedious and duplicated everywhere. I'm also not sure whether it is reliable to derive classes from a dynamic variable (i.e. current_platform) in a distributed environment.

xuechendi: Thanks @comaniac, do you mean support for heterogeneous platforms in the spec decode path? I totally agree that the current code is tedious; do you think extracting the worker selector into a single file would simplify it enough, or do you have another suggestion? I am totally open to discussing the design.
comaniac: I don't mean supporting heterogeneous platforms. I just feel that class MedusaWorker(NonLLMProposerWorkerBase, WorkerCls), which derives from a dynamic WorkerCls, is not trivial, and I'm not sure whether it is safe and reliable.

xuechendi: @comaniac, I see. Alternatively, I can add all the necessary APIs to worker_base.py and make medusa_worker / multi_step_worker and the others derive from "WorkerBase" instead of "Worker"? But that change would be tremendous, which is why I am not sure whether I should do it. I tested the current approach with a dynamic WorkerCls: it works on CUDA and CPU, and also works for HPU in my own dev environment, so I considered it a valid solution.
xuechendi: @comaniac, I updated this PR; WorkerCls is now defined in "vllm/spec_decode/selector.py" instead of being spread all around. Please check whether this looks better.
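The contents of vllm/spec_decode/selector.py are not shown in this excerpt. A plausible sketch of such a module, assembled from the platform checks that appear earlier in this diff (an assumption for illustration, not the code actually merged), would be:

# Hypothetical sketch of vllm/spec_decode/selector.py: pick the
# platform-specific Worker class once, at import time.
from vllm.platforms import current_platform

if current_platform.is_neuron():
    from vllm.worker.neuron_worker import NeuronWorker as WorkerCls
elif current_platform.is_hpu():
    from vllm.worker.hpu_worker import HPUWorker as WorkerCls
elif current_platform.is_openvino():
    from vllm.worker.openvino_worker import OpenVINOWorker as WorkerCls
elif current_platform.is_cpu():
    from vllm.worker.cpu_worker import CPUWorker as WorkerCls
elif current_platform.is_tpu():
    from vllm.worker.tpu_worker import TPUWorker as WorkerCls
elif current_platform.is_xpu():
    from vllm.worker.xpu_worker import XPUWorker as WorkerCls
else:
    from vllm.worker.worker import Worker as WorkerCls

__all__ = ["WorkerCls"]

Under that assumption, callers such as spec_decode_worker.py and medusa_worker.py would simply do "from vllm.spec_decode.selector import WorkerCls" instead of repeating the branching.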
xuechendi: @comaniac, I also verified the distributed case, using the test below