diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index a8f4b09b67274..1197afd15d016 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -22,6 +22,10 @@
 
 class HPUAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "hpu-attn"
+
     @staticmethod
     def get_impl_cls() -> Type["HPUAttentionImpl"]:
         return HPUAttentionImpl
diff --git a/vllm/executor/hpu_executor.py b/vllm/executor/hpu_executor.py
index 220e9eee87bb3..90e667e1965ec 100644
--- a/vllm/executor/hpu_executor.py
+++ b/vllm/executor/hpu_executor.py
@@ -48,9 +48,15 @@ def _create_worker(self,
                        local_rank: int = 0,
                        rank: int = 0,
                        distributed_init_method: Optional[str] = None):
+        if self.speculative_config is not None:
+            module_name = "vllm.spec_decode.spec_decode_worker"
+            class_name = "create_spec_worker"
+        else:
+            module_name = "vllm.worker.hpu_worker"
+            class_name = "HPUWorker"
         wrapper = WorkerWrapperBase(
-            worker_module_name="vllm.worker.hpu_worker",
-            worker_class_name="HPUWorker",
+            worker_module_name=module_name,
+            worker_class_name=class_name,
         )
         wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
                                                       distributed_init_method))
diff --git a/vllm/spec_decode/hpu_draft_model_runner.py b/vllm/spec_decode/hpu_draft_model_runner.py
new file mode 100644
index 0000000000000..873b1970f6798
--- /dev/null
+++ b/vllm/spec_decode/hpu_draft_model_runner.py
@@ -0,0 +1,62 @@
+from typing import List, Optional
+
+import torch
+
+from vllm.logger import init_logger
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.sequence import IntermediateTensors
+from vllm.worker.hpu_model_runner import HPUModelRunner as ModelRunnerBaseCls
+from vllm.worker.hpu_model_runner import ModelInputForHPUWithSamplingMetadata
+
+logger = init_logger(__name__)
+
+# A flag to enable debug prints for the updated input tensors
+# before each step.
+debug_advance_input = False
+# A flag to allow GPU advance step for draft model runner.
+# Set to False for debugging.
+allow_gpu_advance_step = True
+
+
+class HPUTP1DraftModelRunner(ModelRunnerBaseCls):
+    """Specialized model runner for speculative decoding draft model.
+    Since the draft model always execute k forward passes consecutively to
+    generate k speculative tokens in a single speculative decoding step,
+    we could get rid of most CPU-GPU synchronization and data transfer
+    overheads by keeping model input and output tensors on GPU all the time.
+
+    TODOs:
+    1. Support TP > 1 (this requires some designs because we do not expect
+       any broadcasting inside execute_model).
+    """
+
+    def __init__(self, *args, **kwargs):
+        if kwargs.get("return_hidden_states"):
+            raise ValueError(
+                "return_hidden_states is not supported for TP1DraftModelRunner."
+            )
+
+        super().__init__(*args, **kwargs)
+
+        self.indices_of_seq_with_bonus_tokens = None
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        model_input: ModelInputForHPUWithSamplingMetadata,
+        kv_caches: List[torch.Tensor],
+        previous_hidden_states: Optional[torch.Tensor] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        num_steps: int = 1,
+    ) -> Optional[List[SamplerOutput]]:
+        if previous_hidden_states is not None:
+            _, block_size = model_input.input_tokens.shape
+            previous_hidden_states = previous_hidden_states.expand(
+                block_size, -1).unsqueeze(0)
+        return super().execute_model(
+            model_input=model_input,
+            kv_caches=kv_caches,
+            previous_hidden_states=previous_hidden_states,
+            intermediate_tensors=intermediate_tensors,
+            num_steps=num_steps,
+        )
\ No newline at end of file
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 93419765e2c44..6231e431925a4 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -23,6 +23,10 @@
 
 if current_platform.is_cuda_alike():
     from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+elif current_platform.is_hpu():
+    from vllm.spec_decode.hpu_draft_model_runner import (HPUTP1DraftModelRunner
+                                                         as
+                                                         TP1DraftModelRunner)
 
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 7e9b2bd13b48a..154d64068803e 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -354,6 +354,14 @@ def compute_logits(self, *args, **kwargs):
     def sample(self, *args, **kwargs):
         return self.model.sample(*args, **kwargs)
 
+    def generate_proposals(self, *args, **kwargs):
+        return self.model.generate_proposals(*args, **kwargs)
+
+    # sampler property will be used by spec_decode_worker
+    @property
+    def sampler(self):
+        return self.model.sampler
+
 
 class PreparePromptMetadata(NamedTuple):
     input_tokens: torch.Tensor
@@ -514,6 +522,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
     def __init__(
         self,
         vllm_config: VllmConfig,
+        kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
         return_hidden_states: bool = False,
     ):
@@ -539,13 +548,18 @@ def __init__(
         self.pin_memory = is_pin_memory_available()
         self.kv_cache_dtype = self.cache_config.cache_dtype
 
+        num_attn_heads = self.model_config.get_num_attention_heads(
+            self.parallel_config)
+        needs_attn_backend = (num_attn_heads != 0
+                              or self.model_config.is_attention_free)
+
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
             self.model_config.is_attention_free,
-        )
+        ) if needs_attn_backend else None
 
         # Lazy initialization
         self.lora_manager: LRUCacheWorkerLoRAManager = None
@@ -1879,6 +1893,7 @@ def execute_model(
         intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
         warmup_mode=False,
+        previous_hidden_states: Optional[torch.Tensor] = None,
     ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
         if num_steps > 1:
             raise ValueError(
@@ -1922,6 +1937,9 @@ def execute_model(
             "lora_mask": lora_mask,
             **(model_input.multi_modal_kwargs or {}),
         }
+        if previous_hidden_states is not None:
+            execute_model_kwargs.update(
+                {"previous_hidden_states": previous_hidden_states})
         if htorch.utils.internal.is_lazy():
             execute_model_kwargs.update({"bypass_hpu_graphs": not use_graphs})
 
@@ -1987,6 +2005,11 @@ def execute_model(
                     real_batch_size=real_batch_size,
                     is_prompt=is_prompt)
                 self.profiler.record_counter(self.event_start, counters)
+        if self.return_hidden_states:
+            # we only need to pass hidden states of most recent token
+            if model_input.is_prompt:
+                output.prefill_hidden_states = hidden_states
+            output.hidden_states = hidden_states
         return [output]
 
     def shutdown_inc(self):
diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py
index 493f7a9fad098..1df078551cf67 100644
--- a/vllm/worker/hpu_worker.py
+++ b/vllm/worker/hpu_worker.py
@@ -22,7 +22,6 @@
 from vllm.sequence import ExecuteModelRequest
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.hpu_model_runner import HPUModelRunner
-from vllm.worker.model_runner_base import ModelRunnerBase
 from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
                                      WorkerInput)
 
@@ -44,7 +43,7 @@ def __init__(
         rank: int,
         distributed_init_method: str,
         is_driver_worker: bool = False,
-        model_runner_cls: Optional[Type[ModelRunnerBase]] = None,
+        model_runner_cls: Optional[Type[HPUModelRunner]] = None,
     ) -> None:
         WorkerBase.__init__(self, vllm_config=vllm_config)
         self.parallel_config.rank = rank
@@ -60,8 +59,28 @@ def __init__(
             from vllm.utils import init_cached_hf_modules
             init_cached_hf_modules()
 
-        self.model_runner: HPUModelRunner = HPUModelRunner(
-            vllm_config=vllm_config, is_driver_worker=is_driver_worker)
+        # Return hidden states from target model if the draft model is an
+        # mlp_speculator
+        speculative_config = self.speculative_config
+        model_config = self.model_config
+        speculative_args = {} if speculative_config is None \
+            or (speculative_config.draft_model_config.model ==
+                model_config.model) \
+            or (speculative_config.draft_model_config.hf_config.model_type
+                not in ["medusa", "mlp_speculator", "eagle"]) \
+            else {"return_hidden_states": True}
+
+        ModelRunnerClass: Type[HPUModelRunner] = HPUModelRunner
+        if model_runner_cls is not None:
+            ModelRunnerClass = model_runner_cls
+        else:
+            ModelRunnerClass = HPUModelRunner
+        self.model_runner: HPUModelRunner = ModelRunnerClass(
+            vllm_config=vllm_config,
+            kv_cache_dtype=self.cache_config.cache_dtype,
+            is_driver_worker=is_driver_worker,
+            **speculative_args,
+        )
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
         self.cache_engine: List[HPUCacheEngine]
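
Usage sketch (not part of the patch): with a draft model configured, HPUExecutor now loads create_spec_worker instead of HPUWorker, and the draft side runs through HPUTP1DraftModelRunner. The snippet below is a minimal illustration only, assuming an HPU build of a vLLM version from this era where LLM still accepts the speculative_model and num_speculative_tokens engine arguments; the model names are placeholders.

# Hypothetical example; exact engine argument names may differ across vLLM versions.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-6.7b",              # target model (placeholder)
    speculative_model="facebook/opt-125m",  # draft model (placeholder)
    num_speculative_tokens=5,
)
outputs = llm.generate(["Speculative decoding on Gaudi is"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)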