Spec Decode - Remove hard-dependency on GPU
Signed-off-by: Chendi Xue <[email protected]>
xuechendi committed Nov 6, 2024
1 parent 87bd7e0 commit 347f7ae
Showing 7 changed files with 139 additions and 31 deletions.
15 changes: 14 additions & 1 deletion vllm/model_executor/layers/spec_decode_base_sampler.py
@@ -43,6 +43,19 @@ def init_gpu_tensors(self, device: Union[int, str]) -> None:
dtype=torch.long,
device=device)

    def init_tensors(self,
                     device: Union[int, str],
                     device_type: str = 'cuda') -> None:
        assert self.num_accepted_tokens is None
        if isinstance(device, int):
            device = f"{device_type}:{device}"
        self.num_accepted_tokens = torch.tensor(0,
                                                dtype=torch.long,
                                                device=device)
        self.num_emitted_tokens = torch.tensor(0,
                                               dtype=torch.long,
                                               device=device)

@property
def probs_dtype(self):
return torch.float32
@@ -77,7 +90,7 @@ def _create_output(
tensor is [batch_size, k + num_bonus_tokens]
"""
batch_size, k = substitute_token_ids.shape
bonus_token_ids = bonus_token_ids.squeeze()
bonus_token_ids = bonus_token_ids.squeeze(-1)
# Determine the index of the first False value for each row.
limits = (accepted == 0).max(1).indices
limits[~(accepted == 0).any(1)] = k
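The hunk above replaces the CUDA-only init_gpu_tensors with a device-agnostic init_tensors, and switches squeeze() to squeeze(-1) so the batch dimension survives when batch_size == 1. Below is a minimal, self-contained sketch of the same pattern outside vLLM, assuming only torch; CounterHolder is an illustrative name, not a vLLM class.

from typing import Union

import torch


class CounterHolder:
    """Illustrative stand-in for the sampler's accepted/emitted counters."""

    def __init__(self) -> None:
        self.num_accepted_tokens = None
        self.num_emitted_tokens = None

    def init_tensors(self,
                     device: Union[int, str],
                     device_type: str = "cuda") -> None:
        # An int is interpreted as a rank/index on the given device type.
        if isinstance(device, int):
            device = f"{device_type}:{device}"
        self.num_accepted_tokens = torch.tensor(0, dtype=torch.long,
                                                device=device)
        self.num_emitted_tokens = torch.tensor(0, dtype=torch.long,
                                               device=device)


holder = CounterHolder()
holder.init_tensors(0, device_type="cpu")  # works on a machine with no GPU
print(holder.num_accepted_tokens.device)

# squeeze(-1) vs squeeze(): only the trailing size-1 dim is removed, so a
# [1, 1] bonus-token tensor keeps its batch dimension.
bonus = torch.zeros(1, 1, dtype=torch.long)
print(bonus.squeeze().shape, bonus.squeeze(-1).shape)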
21 changes: 18 additions & 3 deletions vllm/spec_decode/medusa_worker.py
@@ -5,14 +5,29 @@

from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker


class MedusaWorker(NonLLMProposerWorkerBase, Worker):
if current_platform.is_neuron():
    from vllm.worker.neuron_worker import NeuronWorker as WorkerCls
elif current_platform.is_hpu():
    from vllm.worker.hpu_worker import HPUWorker as WorkerCls
elif current_platform.is_openvino():
    from vllm.worker.openvino_worker import OpenVINOWorker as WorkerCls
elif current_platform.is_cpu():
    from vllm.worker.cpu_worker import CPUWorker as WorkerCls
elif current_platform.is_tpu():
    from vllm.worker.tpu_worker import TPUWorker as WorkerCls
elif current_platform.is_xpu():
    from vllm.worker.xpu_worker import XPUWorker as WorkerCls
else:
    from vllm.worker.worker import Worker as WorkerCls


class MedusaWorker(NonLLMProposerWorkerBase, WorkerCls):
"""Worker for Medusa.
"""

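MedusaWorker now derives from whichever worker class matches the running platform instead of the CUDA-only Worker. A toy sketch of this import-time base-class selection follows, assuming nothing from vLLM; GpuWorker and CpuWorker are placeholders.

import torch


class GpuWorker:
    device_type = "cuda"


class CpuWorker:
    device_type = "cpu"


# The branch runs once at import time, so the subclass below never touches
# CUDA machinery on GPU-less machines.
if torch.cuda.is_available():
    WorkerCls = GpuWorker
else:
    WorkerCls = CpuWorker


class MedusaLikeWorker(WorkerCls):
    """Same shape as MedusaWorker(NonLLMProposerWorkerBase, WorkerCls)."""


print(MedusaLikeWorker().device_type)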
9 changes: 9 additions & 0 deletions vllm/spec_decode/metrics.py
@@ -6,6 +6,7 @@

from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available


@@ -81,8 +82,16 @@ def init_gpu_tensors(self, rank: int) -> None:
self._rank = rank
self._copy_stream = torch.cuda.Stream()

    def init_tensors(self, rank: int, device_type: str = 'cuda') -> None:
        self._rank = rank
        if device_type == 'cuda':
            self._copy_stream = torch.cuda.Stream()

    def maybe_collect_rejsample_metrics(
            self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
        # Currently using cuda.Event; skip for any non-CUDA-alike platform.
        if not current_platform.is_cuda_alike():
            return None

# If a copy was initiated in the previous call, collect and return.
if self._in_flight_copy is not None:
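The metrics collector now creates its torch.cuda.Stream only when the device type is CUDA, and maybe_collect_rejsample_metrics returns early elsewhere because its timing relies on cuda.Event. A hedged sketch of the same guard, with torch.cuda.is_available() standing in for current_platform.is_cuda_alike() and MetricsLike as an illustrative name:

from typing import Optional

import torch


class MetricsLike:
    def __init__(self) -> None:
        self._rank = 0
        self._copy_stream = None

    def init_tensors(self, rank: int, device_type: str = "cuda") -> None:
        self._rank = rank
        # The async copy stream only makes sense on CUDA-like devices.
        if device_type == "cuda":
            self._copy_stream = torch.cuda.Stream()

    def maybe_collect(self) -> Optional[str]:
        # Mirrors the early return above: no cuda.Event timing off-GPU.
        if not torch.cuda.is_available():
            return None
        return "collected"


m = MetricsLike()
m.init_tensors(rank=0, device_type="cpu")
print(m.maybe_collect())  # None on a CPU-only machine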
28 changes: 23 additions & 5 deletions vllm/spec_decode/multi_step_worker.py
@@ -5,17 +5,35 @@
import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
SequenceGroupMetadata)
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

if current_platform.is_cuda_alike():
    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker


class MultiStepWorker(Worker, ProposerWorkerBase):
if current_platform.is_neuron():
    from vllm.worker.neuron_worker import NeuronWorker as WorkerBaseCls
elif current_platform.is_hpu():
    from vllm.worker.hpu_worker import HPUWorker as WorkerBaseCls
elif current_platform.is_openvino():
    from vllm.worker.openvino_worker import OpenVINOWorker as WorkerBaseCls
elif current_platform.is_cpu():
    from vllm.worker.cpu_worker import CPUWorker as WorkerBaseCls
elif current_platform.is_tpu():
    from vllm.worker.tpu_worker import TPUWorker as WorkerBaseCls
elif current_platform.is_xpu():
    from vllm.worker.xpu_worker import XPUWorker as WorkerBaseCls
else:
    from vllm.worker.worker import Worker as WorkerBaseCls


class MultiStepWorker(WorkerBaseCls, ProposerWorkerBase):
"""The MultiStepWorker is equivalent to a Worker except that it allows
multiple forward passes in a single call, assuming the scheduler has
allocated enough space to store the additional KV. This reduces overhead
@@ -75,7 +93,7 @@ def sampler_output(

# Run model sample_len times.
model_outputs: List[SamplerOutput] = []
if isinstance(
if current_platform.is_cuda_alike() and isinstance(
self.model_runner, TP1DraftModelRunner
) and self.model_runner.supports_gpu_multi_step(expanded_request):
# Here we run the draft_model_runner with multi-step prepare
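TP1DraftModelRunner is now imported only on CUDA-alike platforms, so the isinstance check in sampler_output is guarded by the same flag; otherwise the name would be undefined off-GPU. A small sketch of that pattern, with HAS_CUDA and GpuOnlyRunner as illustrative stand-ins rather than real vLLM symbols:

HAS_CUDA = False  # stand-in for current_platform.is_cuda_alike()

if HAS_CUDA:
    # Hypothetical GPU-only import; never executed on non-CUDA platforms.
    from gpu_only_module import GpuOnlyRunner


def choose_path(model_runner) -> str:
    # The leading flag check keeps this line from raising NameError when
    # GpuOnlyRunner was never imported.
    if HAS_CUDA and isinstance(model_runner, GpuOnlyRunner):
        return "multi-step GPU path"
    return "fallback single-step path"


print(choose_path(object()))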
40 changes: 30 additions & 10 deletions vllm/spec_decode/spec_decode_worker.py
@@ -14,12 +14,16 @@
SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.platforms import current_platform
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SequenceGroupMetadata,
get_all_seq_ids_and_request_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

if current_platform.is_cuda_alike():
    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.medusa_worker import MedusaWorker
@@ -36,9 +40,23 @@
get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len)
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase

if current_platform.is_neuron():
    from vllm.worker.neuron_worker import NeuronWorker as WorkerCls
elif current_platform.is_hpu():
    from vllm.worker.hpu_worker import HPUWorker as WorkerCls
elif current_platform.is_openvino():
    from vllm.worker.openvino_worker import OpenVINOWorker as WorkerCls
elif current_platform.is_cpu():
    from vllm.worker.cpu_worker import CPUWorker as WorkerCls
elif current_platform.is_tpu():
    from vllm.worker.tpu_worker import TPUWorker as WorkerCls
elif current_platform.is_xpu():
    from vllm.worker.xpu_worker import XPUWorker as WorkerCls
else:
    from vllm.worker.worker import Worker as WorkerCls

logger = init_logger(__name__)


Expand All @@ -53,7 +71,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
draft_worker_kwargs = kwargs.copy()

kwargs["model_runner_cls"] = TargetModelRunner
target_worker = Worker(*args, **kwargs)
target_worker = WorkerCls(*args, **kwargs)
# Set the disable_logprobs variable in the TargetModelRunner instance
# as per its value specified in the SpeculativeConfig.
target_worker.model_runner.disable_logprobs =\
@@ -125,7 +143,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
@classmethod
def create_worker(
cls,
scorer_worker: Worker,
scorer_worker: WorkerCls,
draft_worker_kwargs: Dict[str, Any],
disable_mqa_scorer: bool,
disable_by_batch_size: Optional[int],
@@ -158,8 +176,9 @@ def create_worker(
proposer_worker = MedusaWorker(**draft_worker_kwargs)
else:
if draft_tp == 1:
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
if current_platform.is_cuda_alike():
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
else:
if draft_model_config.hf_config.model_type == "eagle":
raise NotImplementedError(
@@ -306,8 +325,9 @@ def init_device(self) -> None:
self.scorer_worker.load_model()
self.proposer_worker.load_model()

self._metrics.init_gpu_tensors(self.rank)
self.spec_decode_sampler.init_gpu_tensors(self.rank)
self._metrics.init_tensors(self.rank, device_type=self.device.type)
self.spec_decode_sampler.init_tensors(self.rank,
device_type=self.device.type)

scorer_cls: Type[SpeculativeScorer]
if self.disable_mqa_scorer:
@@ -1044,11 +1064,11 @@ def get_cache_block_size_bytes(self):
raise NotImplementedError

def start_profile(self):
if isinstance(self.scorer_worker, Worker):
if isinstance(self.scorer_worker, WorkerCls):
self.scorer_worker.start_profile()

def stop_profile(self):
if isinstance(self.scorer_worker, Worker):
if isinstance(self.scorer_worker, WorkerCls):
self.scorer_worker.stop_profile()


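init_device above swaps the init_gpu_tensors calls for init_tensors and passes self.device.type, so each platform gets its counters (and, where applicable, CUDA streams) on its own device. A quick sketch of how torch.device.type feeds that call, assuming only torch; "cpu" is used here only so the snippet runs anywhere.

import torch

# Whatever device the worker was assigned; on a GPU worker this would be
# something like torch.device("cuda", rank).
device = torch.device("cpu")
rank = 0

print(device.type)  # the string handed to init_tensors(..., device_type=...)

# Equivalent to what init_tensors builds internally for an int rank:
counter = torch.tensor(0, dtype=torch.long, device=f"{device.type}:{rank}")
print(counter.device)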
45 changes: 37 additions & 8 deletions vllm/spec_decode/target_model_runner.py
@@ -1,12 +1,42 @@
from typing import List, Optional

from vllm.config import VllmConfig
from vllm.platforms import current_platform
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner)

if current_platform.is_cuda_alike():
    from vllm.worker.model_runner import (
        ModelInputForGPUWithSamplingMetadata as ModelInputCls)  # yapf: disable
    from vllm.worker.model_runner import ModelRunner as ModelRunnerCls
elif current_platform.is_neuron():
    from vllm.worker.neuron_model_runner import (
        ModelInputForNeuron as ModelInputCls)  # yapf: disable
    from vllm.worker.neuron_model_runner import (
        NeuronModelRunner as ModelRunnerCls)  # yapf: disable
elif current_platform.is_hpu():
    from vllm.worker.hpu_model_runner import HPUModelRunner as ModelRunnerCls
    from vllm.worker.hpu_model_runner import (
        ModelInputForHPUWithSamplingMetadata as ModelInputCls)  # yapf: disable
elif current_platform.is_openvino():
    from vllm.worker.openvino_model_runner import ModelInput as ModelInputCls
    from vllm.worker.openvino_model_runner import (
        OpenVINOModelRunner as ModelRunnerCls)  # yapf: disable
elif current_platform.is_cpu():
    from vllm.worker.cpu_model_runner import CPUModelRunner as ModelRunnerCls
    from vllm.worker.cpu_model_runner import (
        ModelInputForCPUWithSamplingMetadata as ModelInputCls)  # yapf: disable
elif current_platform.is_tpu():
    from vllm.worker.tpu_model_runner import ModelInputForTPU as ModelInputCls
    from vllm.worker.tpu_model_runner import TPUModelRunner as ModelRunnerCls
elif current_platform.is_xpu():
    from vllm.worker.xpu_model_runner import (
        ModelInputForXPUWithSamplingMetadata as ModelInputCls)  # yapf: disable
    from vllm.worker.xpu_model_runner import XPUModelRunner as ModelRunnerCls
else:
    raise ValueError(f"Unsupported platform: {current_platform}")

class TargetModelRunner(ModelRunner):

class TargetModelRunner(ModelRunnerCls):
"""Specialized model runner for speculative decoding target model.
In speculative decoding, the log probabilities selected finally may not
be the same ones as selected by the target model sampling. This means
@@ -39,11 +69,10 @@ def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForGPUWithSamplingMetadata:
model_input: ModelInputForGPUWithSamplingMetadata = super(
).prepare_model_input(seq_group_metadata_list, virtual_engine,
finished_requests_ids)
finished_requests_ids: Optional[List[str]] = None,
) -> ModelInputCls:
model_input: ModelInputCls = super().prepare_model_input(
seq_group_metadata_list, virtual_engine, finished_requests_ids)
# If token log probabilities is disabled then skip generating sampler
# CPU output. We directly serialize the GPU sampled_token_id tensors
# as needed. If log probabilities is enabled then synchronize all the
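TargetModelRunner now subclasses a platform-selected ModelRunnerCls and annotates prepare_model_input with the matching ModelInputCls alias. A toy version of that aliasing follows; CpuRunner and CpuModelInput are made-up placeholders, not real vLLM classes.

from dataclasses import dataclass
from typing import List


@dataclass
class CpuModelInput:
    tokens: List[int]


class CpuRunner:
    def prepare_model_input(self, tokens: List[int]) -> CpuModelInput:
        return CpuModelInput(tokens=tokens)


# A real implementation would branch on current_platform here.
ModelRunnerCls = CpuRunner
ModelInputCls = CpuModelInput


class TargetRunnerLike(ModelRunnerCls):
    def prepare_model_input(self, tokens: List[int]) -> ModelInputCls:
        model_input: ModelInputCls = super().prepare_model_input(tokens)
        # Speculative-decoding bookkeeping (e.g. toggling sampler CPU
        # output) would be layered on here.
        return model_input


print(TargetRunnerLike().prepare_model_input([1, 2, 3]))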
12 changes: 8 additions & 4 deletions vllm/spec_decode/util.py
@@ -5,6 +5,7 @@
import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SequenceGroupMetadata,
SequenceOutput)
@@ -247,11 +248,14 @@ def nvtx_range(msg, *args, **kwargs):
Arguments:
msg (string): message to associate with the range
"""
torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
try:
if current_platform.is_cuda_alike():
torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
try:
yield
finally:
torch.cuda.nvtx.range_pop()
else:
yield
finally:
torch.cuda.nvtx.range_pop()


class Timer:
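With this change, nvtx_range degrades to a plain yield on non-CUDA platforms, so call sites need no platform checks of their own. A runnable sketch of the reworked context manager, using torch.cuda.is_available() in place of current_platform.is_cuda_alike():

from contextlib import contextmanager

import torch


@contextmanager
def nvtx_range(msg, *args, **kwargs):
    if torch.cuda.is_available():
        torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
        try:
            yield
        finally:
            torch.cuda.nvtx.range_pop()
    else:
        # No NVTX on this platform; just run the body.
        yield


with nvtx_range("score_proposals batch={}", 8):
    _ = sum(i * i for i in range(1000))  # stand-in for profiled work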
