[Kernel] Generalize speculative decode from CUDA #10094

Closed
Wants to merge 10 commits.
8 changes: 6 additions & 2 deletions vllm/executor/cpu_executor.py
@@ -125,8 +125,12 @@ def _create_worker(
local_rank: int = 0,
rank: int = 0,
):
worker_module_name = "vllm.worker.cpu_worker"
worker_class_name = "CPUWorker"
if self.speculative_config is not None:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
else:
worker_module_name = "vllm.worker.cpu_worker"
worker_class_name = "CPUWorker"

wrapper = WorkerWrapperBase(
worker_module_name=worker_module_name,
15 changes: 14 additions & 1 deletion vllm/model_executor/layers/spec_decode_base_sampler.py
@@ -43,6 +43,19 @@ def init_gpu_tensors(self, device: Union[int, str]) -> None:
dtype=torch.long,
device=device)

def init_tensors(self,
device: Union[int, str],
device_type: str = 'cuda') -> None:
assert self.num_accepted_tokens is None
if isinstance(device, int):
device = f"{device_type}:{device}"
self.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
self.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)

@property
def probs_dtype(self):
return torch.float32
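
The new init_tensors generalizes init_gpu_tensors: the caller names the device type, so an integer rank expands to a device string such as "cpu:0" or "hpu:0" instead of always "cuda:<rank>". A minimal illustration of that device-string construction (assumed values; CPU used so the snippet runs anywhere):

import torch

device, device_type = 0, "cpu"
if isinstance(device, int):
    device = f"{device_type}:{device}"  # -> "cpu:0"
num_accepted_tokens = torch.tensor(0, dtype=torch.long, device=device)
# The counter now lives on the selected platform's device rather than on CUDA.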
@@ -77,7 +90,7 @@ def _create_output(
tensor is [batch_size, k + num_bonus_tokens]
"""
batch_size, k = substitute_token_ids.shape
bonus_token_ids = bonus_token_ids.squeeze()
bonus_token_ids = bonus_token_ids.squeeze(-1)
# Determine the index of the first False value for each row.
limits = (accepted == 0).max(1).indices
limits[~(accepted == 0).any(1)] = k
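
The switch from squeeze() to squeeze(-1) matters when batch_size == 1: a bare squeeze() drops every size-1 dimension and would collapse a [1, 1] bonus-token tensor to a 0-d scalar, whereas squeeze(-1) removes only the trailing dimension and keeps the batch dimension. A small illustration (assumed shapes, not taken from the PR's tests):

import torch

bonus_token_ids = torch.tensor([[42]])    # shape [batch_size=1, 1]
print(bonus_token_ids.squeeze().shape)    # torch.Size([])  -- 0-d tensor
print(bonus_token_ids.squeeze(-1).shape)  # torch.Size([1]) -- batch dimension kept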
48 changes: 48 additions & 0 deletions vllm/spec_decode/cpu_draft_model_runner.py
@@ -0,0 +1,48 @@
from typing import List, Optional

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import IntermediateTensors
from vllm.worker.cpu_model_runner import CPUModelRunner as ModelRunnerBaseCls
from vllm.worker.cpu_model_runner import ModelInputForCPUWithSamplingMetadata

logger = init_logger(__name__)


class CPUTP1DraftModelRunner(ModelRunnerBaseCls):
"""Specialized model runner for speculative decoding draft model.
Since the draft model always execute k forward passes consecutively to
generate k speculative tokens in a single speculative decoding step,
we could get rid of most CPU-GPU synchronization and data transfer
overheads by keeping model input and output tensors on GPU all the time.
TODOs:
1. Support TP > 1 (this requires some designs because we do not expect
any broadcasting inside execute_model).
"""

def __init__(self, *args, **kwargs):
if kwargs.get("return_hidden_states"):
raise ValueError(
"return_hidden_states is not supported for TP1DraftModelRunner."
)
super().__init__(*args, **kwargs)
self.indices_of_seq_with_bonus_tokens = None

@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForCPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
previous_hidden_states: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
return super().execute_model(
model_input=model_input,
kv_caches=kv_caches,
previous_hidden_states=previous_hidden_states,
intermediate_tensors=intermediate_tensors,
num_steps=num_steps,
)
21 changes: 18 additions & 3 deletions vllm/spec_decode/medusa_worker.py
@@ -5,14 +5,29 @@

from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker


class MedusaWorker(NonLLMProposerWorkerBase, Worker):
if current_platform.is_neuron():
from vllm.worker.neuron_worker import NeuronWorker as WorkerCls
elif current_platform.is_hpu():
from vllm.worker.hpu_worker import HPUWorker as WorkerCls
elif current_platform.is_openvino():
from vllm.worker.openvino_worker import OpenVINOWorker as WorkerCls
elif current_platform.is_cpu():
from vllm.worker.cpu_worker import CPUWorker as WorkerCls
elif current_platform.is_tpu():
from vllm.worker.tpu_worker import TPUWorker as WorkerCls
elif current_platform.is_xpu():
from vllm.worker.xpu_worker import XPUWorker as WorkerCls
else:
from vllm.worker.worker import Worker as WorkerCls
Collaborator:

This is not a clean and concise way to support non-CUDA workers, so apparently you'll need some design work here.

Contributor Author (@xuechendi), Nov 20, 2024:

@comaniac, I could put a worker_selector.py in either the worker folder or the spec_decode folder. I didn't do that because, when I discussed this with @LiuXiaoxuanPKU, she preferred to keep this PR as simple as possible.

I would like your opinion here. The idea is that I can extract the code above into a new file, and in spec_decode_worker, medusa_worker, etc. simply do "from vllm.worker.selector import WorkerCls".

Collaborator:

The problem is that I don't think the current PR is simple, given that this logic is tedious and duplicated everywhere. I'm also not sure it is reliable to derive classes from a dynamic variable (i.e. current_platform) in a distributed environment.

Contributor Author (@xuechendi):

Thanks @comaniac, do you mean support for heterogeneous platforms in the spec decode path?
I totally agree that the current code is tedious. Do you think extracting the worker selector into a single file would simplify it, or do you have another suggestion? I am completely open to discussing the design.

Collaborator:

I don't mean supporting heterogeneous platforms. I just feel that class MedusaWorker(NonLLMProposerWorkerBase, WorkerCls), which derives from a dynamic WorkerCls, is non-trivial, and I'm not sure it is safe and reliable.

Contributor Author (@xuechendi):

@comaniac, I see. Alternatively, I could add all the necessary APIs to worker_base.py and make medusa_worker, multi_step_worker, and the others derive from WorkerBase instead of Worker, but that change would be substantial, which is why I'm not sure I should do it.

I tested the current approach with a dynamic WorkerCls: it works on CUDA and CPU, and also on HPU in my own dev environment, so I consider it a valid solution.

Contributor Author (@xuechendi):

@comaniac, I updated this PR: WorkerCls is now defined in "vllm/spec_decode/selector.py" instead of being spread around. Please check whether this looks better.

Contributor Author (@xuechendi):

@comaniac, I verified the distributed case as well, using the test below:

CUDA_VISIBLE_DEVICES=0,1 pytest -v tests/spec_decode/e2e/test_integration_dist_tp2.py::test_draft_model_tp_lt_target_model_tp2
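
For reference, a minimal sketch of the selector module discussed above; the platform dispatch mirrors the inline blocks in this diff, but the exact contents of vllm/spec_decode/selector.py in the updated PR may differ:

# vllm/spec_decode/selector.py (sketch, assuming the same platform checks as above)
from vllm.platforms import current_platform

if current_platform.is_neuron():
    from vllm.worker.neuron_worker import NeuronWorker as WorkerCls
elif current_platform.is_hpu():
    from vllm.worker.hpu_worker import HPUWorker as WorkerCls
elif current_platform.is_openvino():
    from vllm.worker.openvino_worker import OpenVINOWorker as WorkerCls
elif current_platform.is_cpu():
    from vllm.worker.cpu_worker import CPUWorker as WorkerCls
elif current_platform.is_tpu():
    from vllm.worker.tpu_worker import TPUWorker as WorkerCls
elif current_platform.is_xpu():
    from vllm.worker.xpu_worker import XPUWorker as WorkerCls
else:
    from vllm.worker.worker import Worker as WorkerCls

Each consumer would then derive from the single shared alias, for example:

# vllm/spec_decode/medusa_worker.py (sketch)
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.selector import WorkerCls

class MedusaWorker(NonLLMProposerWorkerBase, WorkerCls):
    ...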



class MedusaWorker(NonLLMProposerWorkerBase, WorkerCls):
"""Worker for Medusa.
"""

9 changes: 9 additions & 0 deletions vllm/spec_decode/metrics.py
@@ -6,6 +6,7 @@

from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available


@@ -81,8 +82,16 @@ def init_gpu_tensors(self, rank: int) -> None:
self._rank = rank
self._copy_stream = torch.cuda.Stream()

def init_tensors(self, rank: int, device_type: str = 'cuda') -> None:
self._rank = rank
if device_type == 'cuda':
self._copy_stream = torch.cuda.Stream()

def maybe_collect_rejsample_metrics(
self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
# currently using cuda.Event, skip for any non_cuda_alike platform
if not current_platform.is_cuda_alike():
return None

# If a copy was initiated in the previous call, collect and return.
if self._in_flight_copy is not None:
28 changes: 23 additions & 5 deletions vllm/spec_decode/multi_step_worker.py
@@ -5,17 +5,35 @@
import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
SequenceGroupMetadata)
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker


class MultiStepWorker(Worker, ProposerWorkerBase):
if current_platform.is_neuron():
from vllm.worker.neuron_worker import NeuronWorker as WorkerBaseCls
elif current_platform.is_hpu():
from vllm.worker.hpu_worker import HPUWorker as WorkerBaseCls
elif current_platform.is_openvino():
from vllm.worker.openvino_worker import OpenVINOWorker as WorkerBaseCls
elif current_platform.is_cpu():
from vllm.worker.cpu_worker import CPUWorker as WorkerBaseCls
elif current_platform.is_tpu():
from vllm.worker.tpu_worker import TPUWorker as WorkerBaseCls
elif current_platform.is_xpu():
from vllm.worker.xpu_worker import XPUWorker as WorkerBaseCls
else:
from vllm.worker.worker import Worker as WorkerBaseCls


class MultiStepWorker(WorkerBaseCls, ProposerWorkerBase):
"""The MultiStepWorker is equivalent to a Worker except that it allows
multiple forward passes in a single call, assuming the scheduler has
allocated enough space to store the additional KV. This reduces overhead
@@ -75,7 +93,7 @@ def sampler_output(

# Run model sample_len times.
model_outputs: List[SamplerOutput] = []
if isinstance(
if current_platform.is_cuda_alike() and isinstance(
self.model_runner, TP1DraftModelRunner
) and self.model_runner.supports_gpu_multi_step(expanded_request):
# Here we run the draft_model_runner with multi-step prepare
20 changes: 19 additions & 1 deletion vllm/spec_decode/ngram_worker.py
@@ -4,11 +4,29 @@
import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer

if current_platform.is_cuda_alike():
DEVICE_TYPE = "cuda"
elif current_platform.is_neuron():
DEVICE_TYPE = "neuron"
elif current_platform.is_hpu():
DEVICE_TYPE = "hpu"
elif current_platform.is_openvino():
DEVICE_TYPE = "openvino"
elif current_platform.is_cpu():
DEVICE_TYPE = "cpu"
elif current_platform.is_tpu():
DEVICE_TYPE = "tpu"
elif current_platform.is_xpu():
DEVICE_TYPE = "xpu"
else:
raise ValueError(f"Unsupported platform: {current_platform}")


class NGramWorker(NonLLMProposerWorkerBase):
"""NGramWorker provides a light drafter without need for model.
@@ -34,7 +52,7 @@ def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min

def init_device(self):
self.device = torch.device(f"cuda:{self.local_rank}")
self.device = torch.device(f"{DEVICE_TYPE}:{self.local_rank}")
self.load_model = lambda *args, **kwargs: None

# Current NGramWorker only supports Top1Proposer
50 changes: 39 additions & 11 deletions vllm/spec_decode/spec_decode_worker.py
@@ -14,12 +14,18 @@
SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.platforms import current_platform
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SequenceGroupMetadata,
get_all_seq_ids_and_request_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
elif current_platform.is_cpu():
from vllm.spec_decode.cpu_draft_model_runner import CPUTP1DraftModelRunner

from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.medusa_worker import MedusaWorker
@@ -36,9 +42,23 @@
get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len)
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase

if current_platform.is_neuron():
from vllm.worker.neuron_worker import NeuronWorker as WorkerCls
elif current_platform.is_hpu():
from vllm.worker.hpu_worker import HPUWorker as WorkerCls
elif current_platform.is_openvino():
from vllm.worker.openvino_worker import OpenVINOWorker as WorkerCls
elif current_platform.is_cpu():
from vllm.worker.cpu_worker import CPUWorker as WorkerCls
elif current_platform.is_tpu():
from vllm.worker.tpu_worker import TPUWorker as WorkerCls
elif current_platform.is_xpu():
from vllm.worker.xpu_worker import XPUWorker as WorkerCls
else:
from vllm.worker.worker import Worker as WorkerCls

logger = init_logger(__name__)


@@ -53,7 +73,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
draft_worker_kwargs = kwargs.copy()

kwargs["model_runner_cls"] = TargetModelRunner
target_worker = Worker(*args, **kwargs)
target_worker = WorkerCls(*args, **kwargs)
# Set the disable_logprobs variable in the TargetModelRunner instance
# as per its value specified in the SpeculativeConfig.
target_worker.model_runner.disable_logprobs =\
@@ -125,7 +145,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
@classmethod
def create_worker(
cls,
scorer_worker: Worker,
scorer_worker: WorkerCls,
draft_worker_kwargs: Dict[str, Any],
disable_mqa_scorer: bool,
disable_by_batch_size: Optional[int],
@@ -158,8 +178,15 @@ def create_worker(
proposer_worker = MedusaWorker(**draft_worker_kwargs)
else:
if draft_tp == 1:
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
if current_platform.is_cuda_alike():
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
elif current_platform.is_cpu():
draft_worker_kwargs[
"model_runner_cls"] = CPUTP1DraftModelRunner
else:
raise NotImplementedError(
"current platform does not support EAGLE.")
else:
if draft_model_config.hf_config.model_type == "eagle":
raise NotImplementedError(
@@ -306,8 +333,9 @@ def init_device(self) -> None:
self.scorer_worker.load_model()
self.proposer_worker.load_model()

self._metrics.init_gpu_tensors(self.rank)
self.spec_decode_sampler.init_gpu_tensors(self.rank)
self._metrics.init_tensors(self.rank, device_type=self.device.type)
self.spec_decode_sampler.init_tensors(self.rank,
device_type=self.device.type)

scorer_cls: Type[SpeculativeScorer]
if self.disable_mqa_scorer:
Expand All @@ -320,7 +348,7 @@ def init_device(self) -> None:
"[Speculative Decoding] Use MQA scorer for scoring proposals.")

self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
device=self.device,
device=self.device.type,
Collaborator:

The argument is device, so you shouldn't pass the device type. You could take the device type inside scorer_cls and then you wouldn't need to change this line.

Contributor Author (@xuechendi):

Hi @comaniac, the reason I changed this is that device is annotated as str in the scorer's __init__, but a device was being passed here, so it failed the mypy check:

https://github.com/vllm-project/vllm/blob/main/vllm/spec_decode/interfaces.py#L78-L79

vocab_size=self._vocab_size)

self._configure_model_sampler_for_spec_decode()
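
As the review thread above notes, the scorer's constructor annotates device as a string, which is why the call site now passes self.device.type. A sketch of the alternative the reviewer suggests, normalizing inside the scorer so the call site could keep passing the full torch.device (hypothetical class and signature, not the PR's code):

from typing import Union

import torch

class ScorerDeviceNormalizationSketch:
    def __init__(self, scorer_worker, device: Union[str, torch.device], vocab_size: int):
        # Accept either a torch.device or a string and keep only the device type,
        # e.g. torch.device("cuda:0") -> "cuda", "cpu" -> "cpu".
        if isinstance(device, torch.device):
            device = device.type
        self._scorer_worker = scorer_worker
        self._device = device
        self._vocab_size = vocab_size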
@@ -1090,11 +1118,11 @@ def get_cache_block_size_bytes(self):
raise NotImplementedError

def start_profile(self):
if isinstance(self.scorer_worker, Worker):
if isinstance(self.scorer_worker, WorkerCls):
self.scorer_worker.start_profile()

def stop_profile(self):
if isinstance(self.scorer_worker, Worker):
if isinstance(self.scorer_worker, WorkerCls):
self.scorer_worker.stop_profile()

