Skip to content

Commit

Permalink
hpu update for spec decode
Browse files Browse the repository at this point in the history
Signed-off-by: Chendi Xue <[email protected]>
  • Loading branch information
xuechendi committed Nov 6, 2024
1 parent 7fdf503 commit 23037b4
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 7 deletions.
4 changes: 4 additions & 0 deletions vllm/attention/backends/hpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@

class HPUAttentionBackend(AttentionBackend):

@staticmethod
def get_name() -> str:
    """Return the identifier string of this attention backend."""
    backend_name = "hpu-attn"
    return backend_name

@staticmethod
def get_impl_cls() -> Type["HPUAttentionImpl"]:
    """Return the attention implementation class paired with this backend."""
    impl_cls = HPUAttentionImpl
    return impl_cls
Expand Down
10 changes: 8 additions & 2 deletions vllm/executor/hpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,15 @@ def _create_worker(self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None):
if self.speculative_config is not None:
module_name = "vllm.spec_decode.spec_decode_worker"
class_name = "create_spec_worker"
else:
module_name = "vllm.worker.hpu_worker"
class_name = "HPUWorker"
wrapper = WorkerWrapperBase(
worker_module_name="vllm.worker.hpu_worker",
worker_class_name="HPUWorker",
worker_module_name=module_name,
worker_class_name=class_name,
)
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method))
Expand Down
62 changes: 62 additions & 0 deletions vllm/spec_decode/hpu_draft_model_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from typing import List, Optional

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import IntermediateTensors
from vllm.worker.hpu_model_runner import HPUModelRunner as ModelRunnerBaseCls
from vllm.worker.hpu_model_runner import ModelInputForHPUWithSamplingMetadata

# Module-level logger, named after this module.
logger = init_logger(__name__)

# A flag to enable debug prints for the updated input tensors
# before each step.
# NOTE(review): this flag is not read anywhere in this file as shown;
# confirm whether it is still needed.
debug_advance_input = False
# A flag to allow GPU advance step for draft model runner.
# Set to False for debugging.
# NOTE(review): this flag is not read anywhere in this file as shown, and
# the name refers to GPU although this runner targets HPU; confirm.
allow_gpu_advance_step = True


class HPUTP1DraftModelRunner(ModelRunnerBaseCls):
    """Draft-model runner used for speculative decoding on HPU.

    This is a thin layer over the base HPU model runner. It only

    * refuses to be constructed with ``return_hidden_states`` enabled, and
    * reshapes ``previous_hidden_states`` to line up with the input tokens
      before delegating ``execute_model`` to the base class.

    Everything else (input preparation, the forward pass, sampling) is
    inherited unchanged.

    TODOs:
    1. Support TP > 1 (this requires some designs because we do not expect
       any broadcasting inside execute_model).
    """

    def __init__(self, *args, **kwargs):
        """Build the runner; all arguments are forwarded to the base class.

        Raises:
            ValueError: if ``return_hidden_states`` is passed as a truthy
                keyword argument.
        """
        if kwargs.get("return_hidden_states"):
            raise ValueError(
                "return_hidden_states is not supported for TP1DraftModelRunner."
            )

        super().__init__(*args, **kwargs)

        self.indices_of_seq_with_bonus_tokens = None

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForHPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        previous_hidden_states: Optional[torch.Tensor] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
        warmup_mode: bool = False,
    ) -> Optional[List[SamplerOutput]]:
        """Run the draft model once through the base runner.

        Args:
            model_input: prepared model input; ``input_tokens`` must be a
                2-D tensor because its shape is unpacked as
                ``(batch_size, block_size)``.
            kv_caches: KV cache tensors, forwarded unchanged.
            previous_hidden_states: optional hidden states to feed to the
                draft model. A 1-D tensor is treated as a single row.
                Each row is repeated along a new token axis so the result
                is ``(rows, block_size, hidden)``.
            intermediate_tensors: forwarded unchanged.
            num_steps: forwarded unchanged.
            warmup_mode: forwarded unchanged; accepted so this override
                stays call-compatible with the base ``execute_model``.

        Returns:
            Whatever the base class ``execute_model`` returns.
        """
        if previous_hidden_states is not None:
            batch_size, block_size = model_input.input_tokens.shape
            if previous_hidden_states.dim() == 1:
                # A bare (hidden,) vector is a single sequence.
                previous_hidden_states = previous_hidden_states.unsqueeze(0)
            # (rows, hidden) -> (rows, block_size, hidden). Repeating along
            # a new token axis keeps one row per sequence; expanding the
            # row axis itself only works when there is exactly one row.
            previous_hidden_states = previous_hidden_states.unsqueeze(
                1).expand(-1, block_size, -1)
            # NOTE(review): input_tokens may carry more batch rows than
            # there are hidden-state rows (presumably batch padding done
            # by the base runner - confirm). Fill the gap with zeros so
            # the leading dimensions agree.
            missing_rows = batch_size - previous_hidden_states.shape[0]
            if missing_rows > 0:
                zero_rows = previous_hidden_states.new_zeros(
                    (missing_rows, *previous_hidden_states.shape[1:]))
                previous_hidden_states = torch.cat(
                    [previous_hidden_states, zero_rows], dim=0)
        return super().execute_model(
            model_input=model_input,
            kv_caches=kv_caches,
            previous_hidden_states=previous_hidden_states,
            intermediate_tensors=intermediate_tensors,
            num_steps=num_steps,
            warmup_mode=warmup_mode,
        )
4 changes: 4 additions & 0 deletions vllm/spec_decode/spec_decode_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@

if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
elif current_platform.is_hpu():
from vllm.spec_decode.hpu_draft_model_runner import (HPUTP1DraftModelRunner
as
TP1DraftModelRunner)

from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
Expand Down
25 changes: 24 additions & 1 deletion vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,14 @@ def compute_logits(self, *args, **kwargs):
def sample(self, *args, **kwargs):
    """Forward every argument to the wrapped model's ``sample``."""
    wrapped_sample = self.model.sample
    return wrapped_sample(*args, **kwargs)

def generate_proposals(self, *args, **kwargs):
    """Forward every argument to the wrapped model's ``generate_proposals``."""
    wrapped_generate = self.model.generate_proposals
    return wrapped_generate(*args, **kwargs)

# spec_decode_worker reads ``sampler`` as an attribute, hence a property.
@property
def sampler(self):
    """Expose the wrapped model's ``sampler`` attribute."""
    wrapped_model = self.model
    return wrapped_model.sampler


class PreparePromptMetadata(NamedTuple):
input_tokens: torch.Tensor
Expand Down Expand Up @@ -514,6 +522,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
def __init__(
self,
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
return_hidden_states: bool = False,
):
Expand All @@ -539,13 +548,18 @@ def __init__(
self.pin_memory = is_pin_memory_available()
self.kv_cache_dtype = self.cache_config.cache_dtype

num_attn_heads = self.model_config.get_num_attention_heads(
self.parallel_config)
needs_attn_backend = (num_attn_heads != 0
or self.model_config.is_attention_free)

self.attn_backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
)
) if needs_attn_backend else None

# Lazy initialization
self.lora_manager: LRUCacheWorkerLoRAManager = None
Expand Down Expand Up @@ -1879,6 +1893,7 @@ def execute_model(
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
warmup_mode=False,
previous_hidden_states: Optional[torch.Tensor] = None,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
if num_steps > 1:
raise ValueError(
Expand Down Expand Up @@ -1922,6 +1937,9 @@ def execute_model(
"lora_mask": lora_mask,
**(model_input.multi_modal_kwargs or {}),
}
if previous_hidden_states is not None:
execute_model_kwargs.update(
{"previous_hidden_states": previous_hidden_states})
if htorch.utils.internal.is_lazy():
execute_model_kwargs.update({"bypass_hpu_graphs": not use_graphs})

Expand Down Expand Up @@ -1987,6 +2005,11 @@ def execute_model(
real_batch_size=real_batch_size,
is_prompt=is_prompt)
self.profiler.record_counter(self.event_start, counters)
if self.return_hidden_states:
# we only need to pass hidden states of most recent token
if model_input.is_prompt:
output.prefill_hidden_states = hidden_states
output.hidden_states = hidden_states
return [output]

def shutdown_inc(self):
Expand Down
27 changes: 23 additions & 4 deletions vllm/worker/hpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from vllm.sequence import ExecuteModelRequest
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.hpu_model_runner import HPUModelRunner
from vllm.worker.model_runner_base import ModelRunnerBase
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)

Expand All @@ -44,7 +43,7 @@ def __init__(
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[ModelRunnerBase]] = None,
model_runner_cls: Optional[Type[HPUModelRunner]] = None,
) -> None:
WorkerBase.__init__(self, vllm_config=vllm_config)
self.parallel_config.rank = rank
Expand All @@ -60,8 +59,28 @@ def __init__(
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()

self.model_runner: HPUModelRunner = HPUModelRunner(
vllm_config=vllm_config, is_driver_worker=is_driver_worker)
# Return hidden states from the target model when the draft model is a
# separate medusa, mlp_speculator or eagle model
speculative_config = self.speculative_config
model_config = self.model_config
speculative_args = {} if speculative_config is None \
or (speculative_config.draft_model_config.model ==
model_config.model) \
or (speculative_config.draft_model_config.hf_config.model_type
not in ["medusa", "mlp_speculator", "eagle"]) \
else {"return_hidden_states": True}

ModelRunnerClass: Type[HPUModelRunner] = HPUModelRunner
if model_runner_cls is not None:
ModelRunnerClass = model_runner_cls
else:
ModelRunnerClass = HPUModelRunner
self.model_runner: HPUModelRunner = ModelRunnerClass(
vllm_config=vllm_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
**speculative_args,
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: List[HPUCacheEngine]
Expand Down

0 comments on commit 23037b4

Please sign in to comment.