From 3c590b1bf37b95ebd03d2cc378a82e63b0363637 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Fri, 26 Jul 2024 13:59:19 +0530 Subject: [PATCH 01/34] initial changes to support EAGLE --- vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/eagle.py | 85 +++++++++++++++++++++ vllm/sequence.py | 5 +- vllm/spec_decode/draft_model_runner.py | 6 ++ vllm/spec_decode/spec_decode_worker.py | 13 +++- vllm/transformers_utils/config.py | 3 +- vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/eagle.py | 49 ++++++++++++ vllm/worker/model_runner.py | 4 + vllm/worker/worker.py | 2 +- vllm/worker/worker_base.py | 18 ++++- 11 files changed, 179 insertions(+), 9 deletions(-) create mode 100644 vllm/model_executor/models/eagle.py create mode 100644 vllm/transformers_utils/configs/eagle.py diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 7df5b8fa64710..fc87f3bcb16a2 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -71,6 +71,7 @@ "XverseForCausalLM": ("xverse", "XverseForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "MedusaModel": ("medusa", "Medusa"), + "EAGLEModel": ("eagle", "EAGLE"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), "JambaForCausalLM": ("jamba", "JambaForCausalLM") } diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py new file mode 100644 index 0000000000000..61e784cd828de --- /dev/null +++ b/vllm/model_executor/models/eagle.py @@ -0,0 +1,85 @@ +from typing import Iterable, List, Optional, Tuple + +import torch +import torch.nn as nn + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.transformers_utils.configs.eagle import EAGLEConfig + + +class EAGLE(nn.Module): + + def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None: + super().__init__() + self.config = config + + architectures = getattr(self.config.model, "architectures", []) + for arch in architectures: + model_cls = ModelRegistry.load_model_cls(arch) + if model_cls is not None: + break + + self.model = model_cls(self.config.model, *args, **kwargs) + self.fc = nn.Linear(config.model.hidden_size * 2, + config.model.hidden_size, + bias=False) + + self.token_map = None + + @property + def sampler(self): + return self.model.sampler + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + + tok_embeds = self.model.model.embed_tokens(input_ids) + inputs_embeds = self.fc(torch.cat([tok_embeds, previous_hidden_states], dim=-1)) + + hidden_states = self.model.model(input_ids=None, + inputs_embeds=inputs_embeds, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + return self.model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + model_weights = [] + for name, loaded_weight in weights: + if name.startswith("fc."): + weight_loader = getattr(self.fc.weight, "weight_loader", + default_weight_loader) + weight_loader(self.fc.weight, loaded_weight) + elif name.startswith("lm_head.") or name.startswith("model."): + model_weights.append((name, loaded_weight)) + else: + model_weights.append((f"model.{name}", loaded_weight)) + + self.model.load_weights(model_weights) \ No newline at end of file diff --git a/vllm/sequence.py b/vllm/sequence.py index 72821ecea0f47..1b7497ed62732 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -872,6 +872,9 @@ class SamplerOutput: # Optional last hidden states from the model. hidden_states: Optional[torch.Tensor] = None + # Optional prefill hidden states from the model (used for models like EAGLE). + prefill_hidden_states: Optional[torch.Tensor] = None + def __getitem__(self, idx: int): return self.outputs[idx] @@ -994,7 +997,7 @@ class ExecuteModelRequest: # The number of requests in the running queue. running_queue_size: int = 0 # Optional hidden states from prior step. - previous_hidden_states: Optional[HiddenStates] = None + previous_hidden_states: Union[HiddenStates, torch.Tensor, None] = None # The number of forward steps to run. num_steps: int = 1 # Finished request ids since last step. diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 95071ecb6c8da..586f11ebb4145 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -234,6 +234,7 @@ def execute_model( self, model_input: ModelInputForGPUWithSamplingMetadata, kv_caches: List[torch.Tensor], + previous_hidden_states: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: @@ -313,9 +314,13 @@ def execute_model( model_executable = self.model outputs: List[SamplerOutput] = [] + hidden_states = previous_hidden_states for step in range(num_steps): multi_modal_kwargs = model_input.multi_modal_kwargs or {} + kwargs = {"previous_hidden_states": hidden_states} \ + if previous_hidden_states is not None else {} + # Run model hidden_states = model_executable( input_ids=model_input.input_tokens, @@ -324,6 +329,7 @@ def execute_model( attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **multi_modal_kwargs, + **kwargs, ) # Compute the logits. diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 98960b88f719f..c91c3137a2d2f 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -144,6 +144,11 @@ def create_worker( draft_worker_kwargs[ "model_runner_cls"] = TP1DraftModelRunner else: + if draft_worker_kwargs[ + "model_config"].hf_config.model_type == "eagle": + raise NotImplementedError( + "EAGLE does not support TP > 1 yet") + allow_zero_draft_token_step = False proposer_worker = MultiStepWorker(**draft_worker_kwargs) @@ -450,8 +455,6 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, not called, meaning that the kv-cache in proposer for requests is not updated, so they cannot enable spec decode in the rest decoding. """ - if not skip_proposer: - self.proposer_worker.execute_model(execute_model_req) sampler_output = self.scorer_worker.execute_model(execute_model_req) assert len(sampler_output) == 1 @@ -466,6 +469,10 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, else: self.previous_hidden_states.update( execute_model_req.seq_group_metadata_list, hidden_states) + + if not skip_proposer: + execute_model_req.previous_hidden_states = sampler_output.prefill_hidden_states + self.proposer_worker.execute_model(execute_model_req) sampler_output_to_return = (self._serialize_sampler_output_no_logprobs( execute_model_req=execute_model_req, sampler_output=sampler_output) @@ -530,6 +537,8 @@ def _run_speculative_decoding_step( #TODO: Fix it #5814 raise RuntimeError("Cannot handle cases where distributed draft " "workers generate no tokens") + + execute_model_req.previous_hidden_states = None proposal_scores = self.scorer.score_proposals( execute_model_req, diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 652505a892142..d8d07b230aefb 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -8,7 +8,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, JAISConfig, MedusaConfig, MLPSpeculatorConfig, MPTConfig, - RWConfig) + RWConfig, EAGLEConfig) if VLLM_USE_MODELSCOPE: from modelscope import AutoConfig @@ -26,6 +26,7 @@ "jais": JAISConfig, "mlp_speculator": MLPSpeculatorConfig, "medusa": MedusaConfig, + "eagle": EAGLEConfig, } for name, cls in _CONFIG_REGISTRY.items(): diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 51de11ca3e42a..c93f64e73efef 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -6,6 +6,7 @@ from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.medusa import MedusaConfig +from vllm.transformers_utils.configs.eagle import EAGLEConfig from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.mpt import MPTConfig @@ -16,5 +17,6 @@ "RWConfig", "JAISConfig", "MedusaConfig", + "EAGLEConfig", "MLPSpeculatorConfig", ] diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py new file mode 100644 index 0000000000000..9436b9b8185f5 --- /dev/null +++ b/vllm/transformers_utils/configs/eagle.py @@ -0,0 +1,49 @@ +import os +from typing import Optional, Union + +from transformers import AutoConfig, PretrainedConfig + + +class EAGLEConfig(PretrainedConfig): + model_type = "eagle" + + def __init__(self, + model: Union[PretrainedConfig, dict, None] = None, + truncated_vocab_size: Optional[int] = None, + **kwargs): + + model_config = None if model is None else ( + AutoConfig.for_model(**model) if isinstance(model, dict) else model) + + for k,v in kwargs.items(): + if k != "architectures" and k != "model_type" and hasattr(model_config, k): + setattr(model_config, k, v) + + self.model = model_config + + if self.model is None: + self.truncated_vocab_size = None + else: + self.truncated_vocab_size = self.vocab_size if truncated_vocab_size is None\ + else truncated_vocab_size + + if "architectures" not in kwargs: + kwargs["architectures"] = ["EAGLEModel"] + + super().__init__(**kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + **kwargs, + ) -> "EAGLEConfig": + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs) + return cls.from_dict(config_dict, **kwargs) + + def __getattribute__(self, key): + try: + return super().__getattribute__(key) + except: + return getattr(self.model, key) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 86d26b4a84c36..65f91a79f8ab0 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1383,6 +1383,10 @@ def execute_model( if model_input.is_prompt: hidden_states = hidden_or_intermediate_states.index_select( 0, indices) + + prefill_hidden_states = hidden_or_intermediate_states.roll(shifts=1, dims=0) + prefill_hidden_states.masked_fill_((model_input.input_positions == 0).unsqueeze(-1), 0) + output.prefill_hidden_states = prefill_hidden_states elif decode_meta.use_cuda_graph: hidden_states = hidden_or_intermediate_states[:len(indices)] else: diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index f3c379d1aa34d..7739d001071e3 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -79,7 +79,7 @@ def __init__( or (speculative_config.draft_model_config.model == model_config.model) \ or (speculative_config.draft_model_config.hf_config.model_type - not in ["medusa", "mlp_speculator"]) \ + not in ["medusa", "mlp_speculator", "eagle"]) \ else {"return_hidden_states": True} ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 03e3857e23c4b..43c7ad2237685 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -10,7 +10,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.platforms import current_platform -from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, +from vllm.sequence import (ExecuteModelRequest, HiddenStates, IntermediateTensors, SamplerOutput) from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) @@ -269,10 +269,20 @@ def execute_model( intermediate_tensors = IntermediateTensors( get_pp_group().recv_tensor_dict()) + if isinstance(execute_model_req.previous_hidden_states, torch.Tensor): + kwargs = {"previous_hidden_states": execute_model_req.previous_hidden_states} + elif isinstance(execute_model_req.previous_hidden_states, HiddenStates): + kwargs = {"previous_hidden_states": execute_model_req.previous_hidden_states.hidden_states} + else: + kwargs = {} + output = self.model_runner.execute_model( - model_input, self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None, intermediate_tensors, - num_steps) + model_input=model_input, + kv_caches=self.kv_cache[worker_input.virtual_engine] if self.kv_cache is not None else None, + intermediate_tensors=intermediate_tensors, + num_steps=num_steps, + **kwargs, + ) if not get_pp_group().is_last_rank: # output is IntermediateTensors From 5f5bed1b101474595543fdce740ef06d7125baa5 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Fri, 26 Jul 2024 17:19:16 +0530 Subject: [PATCH 02/34] handling hidden_states in case of bonus tokens since EAGLE will need it --- vllm/sequence.py | 29 +++++++++++++++++++++++--- vllm/spec_decode/multi_step_worker.py | 8 ++++++- vllm/spec_decode/spec_decode_worker.py | 4 +++- 3 files changed, 36 insertions(+), 5 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 1b7497ed62732..a4ecefbab2b49 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -955,17 +955,24 @@ class HiddenStates: dimension of the hidden_states tensor""" def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata], - hidden_states: torch.Tensor): + hidden_states: torch.Tensor, + bonus_token_previous_hidden_states: Optional[torch.Tensor] = None): assert len(seq_group_metadata_list) == len(hidden_states) self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list) self.hidden_states: torch.Tensor = hidden_states + self.bonus_token_previous_hidden_states: Optional[torch.Tensor] = bonus_token_previous_hidden_states def update(self, seq_group_metadata_list: List[SequenceGroupMetadata], - hidden_states: torch.Tensor) -> None: + hidden_states: torch.Tensor, + bonus_token_previous_hidden_states: Optional[torch.Tensor] = None): """Update hidden states from target model invocation.""" assert len(seq_group_metadata_list) == len(hidden_states) self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) self.hidden_states = torch.cat([self.hidden_states, hidden_states]) + # Adding dummy hidden_states to this to maintain same shape + self.bonus_token_previous_hidden_states = torch.cat([ + self.hidden_states, + hidden_states if bonus_token_previous_hidden_states is None else bonus_token_previous_hidden_states]) def prune(self, seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: @@ -975,8 +982,24 @@ def prune(self, # Batch contents changed - prune removed sequences. index = [self.seq_ids.index(seq_id) for seq_id in seq_ids] self.hidden_states = self.hidden_states[index] + if self.bonus_token_previous_hidden_states is not None: + self.bonus_token_previous_hidden_states = self.bonus_token_previous_hidden_states[index] self.seq_ids = seq_ids - + + def expand_with_bonus_tokens(self, + seq_with_bonus_token_in_last_step: set) -> None: + if self.bonus_token_previous_hidden_states is None or not seq_with_bonus_token_in_last_step: + return + + index = [] + for seq_id in self.seq_ids: + i = self.seq_ids.index(seq_id) + if seq_id in seq_with_bonus_token_in_last_step: + index.append(i + len(self.seq_ids)) + index.append(i) + + self.hidden_states = torch.cat([self.hidden_states, + self.bonus_token_previous_hidden_states])[index] @dataclass class ExecuteModelRequest: diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 91689324557b5..1514264b76170 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -4,7 +4,7 @@ import torch -from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData, +from vllm.sequence import (ExecuteModelRequest, HiddenStates, SamplerOutput, SequenceData, SequenceGroupMetadata) from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.interfaces import (SpeculativeProposals, @@ -153,6 +153,12 @@ def _expand_execute_model_request( updated_execute_model_req.seq_group_metadata_list =\ updated_seq_group_metadata_list + + if isinstance(updated_execute_model_req.previous_hidden_states, + HiddenStates): + updated_execute_model_req.previous_hidden_states\ + .expand_with_bonus_tokens(seq_with_bonus_token_in_last_step) + return updated_execute_model_req, indices_of_original_sequence_groups @staticmethod diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index c91c3137a2d2f..99604424555e4 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -644,10 +644,12 @@ def _verify_tokens( accepted_index = accepted_token_ids + 1 # Convert -1 to 0 accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) index = accepted_index[:, None, None].expand(-1, 1, hs_size) + bonus_token_previous_hidden_states = hidden_states[:, -2] # b x d hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d # Store hidden states from target model for subsequent decode step self.previous_hidden_states = HiddenStates(seq_group_metadata_list, - hidden_states) + hidden_states, + bonus_token_previous_hidden_states) return accepted_token_ids, logprobs From 023e72d7dafadd611e0559988d2bcdf364f1f683 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Fri, 26 Jul 2024 17:59:12 +0530 Subject: [PATCH 03/34] enabling CUDA graph --- vllm/spec_decode/draft_model_runner.py | 7 ++++- vllm/worker/model_runner.py | 36 +++++++++++++++++++------- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 586f11ebb4145..e76e1bccf476f 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -310,11 +310,16 @@ def execute_model( graph_batch_size = model_input.input_tokens.shape[0] model_executable = (self.graph_runners[model_input.virtual_engine] [graph_batch_size]) + hidden_states = torch.cat([previous_hidden_states, + torch.empty([graph_batch_size - previous_hidden_states.shape[0], + *previous_hidden_states.shape[1:]], + dtype=previous_hidden_states.dtype, + device=previous_hidden_states.device)]) else: model_executable = self.model + hidden_states = previous_hidden_states outputs: List[SamplerOutput] = [] - hidden_states = previous_hidden_states for step in range(num_steps): multi_modal_kwargs = model_input.multi_modal_kwargs or {} diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 65f91a79f8ab0..5e8ef06d3c7e9 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,5 +1,6 @@ import dataclasses import gc +import inspect import time import warnings import weakref @@ -1035,6 +1036,13 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: slot_mapping.fill_(_PAD_SLOT_ID) seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() + previous_hidden_states = None + if "previous_hidden_states" in inspect.signature(self.execute_model).parameters: + previous_hidden_states = torch.empty( + [max_batch_size, self.model_config.get_hidden_size()], + dtype=self.model_config.dtype, + device=self.device) + intermediate_inputs = None if not get_pp_group().is_first_rank: intermediate_inputs = self.model.make_empty_intermediate_tensors( @@ -1202,6 +1210,9 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: "stream": graph_capture_context.stream } + if previous_hidden_states is not None: + capture_inputs["previous_hidden_states"] = previous_hidden_states[:batch_size] + if self.has_seqlen_agnostic: # Only used by Mamba-based models CUDA graph atm (Jamba) capture_inputs.update({ @@ -1440,11 +1451,11 @@ def capture( # Note one iteration is not enough for torch.jit.script for _ in range(_NUM_WARMUP_ITERS): self.model( - input_ids, - positions, - kv_caches, - attn_metadata, - intermediate_inputs, + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_inputs, **kwargs, ) torch.cuda.synchronize() @@ -1453,11 +1464,11 @@ def capture( self._graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): output_hidden_or_intermediate_states = self.model( - input_ids, - positions, - kv_caches, - attn_metadata, - intermediate_inputs, + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_inputs, **kwargs, ) if hidden_or_intermediate_states is not None: @@ -1534,6 +1545,11 @@ def forward( if "seqlen_agnostic_capture_inputs" in self.input_buffers: self.model.copy_inputs_before_cuda_graphs(self.input_buffers, **kwargs) + + if "previous_hidden_states" in self.input_buffers: + self.input_buffers["previous_hidden_states"].copy_( + kwargs["previous_hidden_states"], non_blocking=True) + if intermediate_tensors is not None: for key in intermediate_tensors.tensors: self.input_buffers[key].copy_(intermediate_tensors[key], From 8ac1570755cedb6ca5b53ae4919273346bd47e88 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Fri, 26 Jul 2024 18:58:16 +0530 Subject: [PATCH 04/34] adding E2E test and formatting --- .../spec_decode/e2e/test_eagle_correctness.py | 224 ++++++++++++++++++ vllm/model_executor/models/eagle.py | 35 ++- vllm/sequence.py | 46 ++-- vllm/spec_decode/draft_model_runner.py | 18 +- vllm/spec_decode/multi_step_worker.py | 8 +- vllm/spec_decode/spec_decode_worker.py | 15 +- vllm/transformers_utils/config.py | 6 +- vllm/transformers_utils/configs/__init__.py | 2 +- vllm/transformers_utils/configs/eagle.py | 23 +- vllm/worker/model_runner.py | 24 +- vllm/worker/worker_base.py | 32 ++- 11 files changed, 348 insertions(+), 85 deletions(-) create mode 100644 tests/spec_decode/e2e/test_eagle_correctness.py diff --git a/tests/spec_decode/e2e/test_eagle_correctness.py b/tests/spec_decode/e2e/test_eagle_correctness.py new file mode 100644 index 0000000000000..f15e33becd3fb --- /dev/null +++ b/tests/spec_decode/e2e/test_eagle_correctness.py @@ -0,0 +1,224 @@ +"""This docstring details important information on the testing methodology. + +Most of the tests rely on "greedy equality", where we expect the output of +speculative decoding on a sequence to exactly match the output of normal non- +speculative decoding. + +Since speculative decoding with rejection sampling guarantees that the output +distribution matches the target model's output distribution (up to hardware +numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy +equality. + +However, we still need to verify below scenario could be passed: + * Batch size 1 greedy equality + * Batch size >1 greedy equality + * Test greedy equality under preemption + * Test greedy equality under various number of speculative tokens. + +With those tests, we can say at least, Medusa would not break the +correctess for the target model outputs. +""" + +import pytest + +from .conftest import run_greedy_equality_correctness_test + +# main model +MAIN_MODEL = "JackFram/llama-68m" + +# speculative model +SPEC_MODEL = "abhigoyal/vllm-eagle-llama-68m-random" + +# max. number of speculative tokens: this corresponds to +# num_heads in the config.json of the speculator model. +MAX_SPEC_TOKENS = 4 + +# precision +PRECISION = "float32" + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify greedy equality with different batch size.""" + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "block_size": 8, + # 2 for small prompt, 256//8 for generated. + "num_gpu_blocks_override": 2 + 256 // 8, + "max_model_len": (2 + 256 // 8) * 8, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use small output len for fast test. + 128, + ]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator, + test_llm_generator, + batch_size: int, + output_len: int): + """Verify greedy equality, even when some sequences are preempted mid- + generation. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": k, + } + # Try a range of num. speculative tokens + for k in range(1, 1 + MAX_SPEC_TOKENS) + ]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mlp_different_k(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify that mlp speculative decoding produces exact equality + to without spec decode with different values of num_speculative_tokens. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_disable_by_batch_size": 4 + }]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mlp_disable_queue(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify that mlp speculative decoding produces exact equality + to without spec decode when speculation is disabled for large + batch sizes. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +if __name__ == "__main__": + import pytest + pytest.main([__file__]) diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index 61e784cd828de..f0758b05e61d7 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -25,8 +25,8 @@ def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None: self.model = model_cls(self.config.model, *args, **kwargs) self.fc = nn.Linear(config.model.hidden_size * 2, - config.model.hidden_size, - bias=False) + config.model.hidden_size, + bias=False) self.token_map = None @@ -45,22 +45,21 @@ def forward( ) -> torch.Tensor: tok_embeds = self.model.model.embed_tokens(input_ids) - inputs_embeds = self.fc(torch.cat([tok_embeds, previous_hidden_states], dim=-1)) - - hidden_states = self.model.model(input_ids=None, - inputs_embeds=inputs_embeds, - positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - intermediate_tensors=intermediate_tensors) + inputs_embeds = self.fc( + torch.cat([tok_embeds, previous_hidden_states], dim=-1)) + + hidden_states = self.model.model( + input_ids=None, + inputs_embeds=inputs_embeds, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors) return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - return self.model.compute_logits(hidden_states, - sampling_metadata) + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + return self.model.compute_logits(hidden_states, sampling_metadata) def sample( self, @@ -82,4 +81,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: model_weights.append((f"model.{name}", loaded_weight)) - self.model.load_weights(model_weights) \ No newline at end of file + self.model.load_weights(model_weights) diff --git a/vllm/sequence.py b/vllm/sequence.py index a4ecefbab2b49..9c4e8aef1dc5d 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -872,7 +872,8 @@ class SamplerOutput: # Optional last hidden states from the model. hidden_states: Optional[torch.Tensor] = None - # Optional prefill hidden states from the model (used for models like EAGLE). + # Optional prefill hidden states from the model + # (used for models like EAGLE). prefill_hidden_states: Optional[torch.Tensor] = None def __getitem__(self, idx: int): @@ -954,17 +955,22 @@ class HiddenStates: seq_ids are the sequence ids of each entry of the batch dimension of the hidden_states tensor""" - def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata], - hidden_states: torch.Tensor, - bonus_token_previous_hidden_states: Optional[torch.Tensor] = None): + def __init__( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + hidden_states: torch.Tensor, + bonus_token_previous_hidden_states: Optional[torch.Tensor] = None): assert len(seq_group_metadata_list) == len(hidden_states) self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list) self.hidden_states: torch.Tensor = hidden_states - self.bonus_token_previous_hidden_states: Optional[torch.Tensor] = bonus_token_previous_hidden_states + self.bonus_token_previous_hidden_states: Optional[ + torch.Tensor] = bonus_token_previous_hidden_states - def update(self, seq_group_metadata_list: List[SequenceGroupMetadata], - hidden_states: torch.Tensor, - bonus_token_previous_hidden_states: Optional[torch.Tensor] = None): + def update( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + hidden_states: torch.Tensor, + bonus_token_previous_hidden_states: Optional[torch.Tensor] = None): """Update hidden states from target model invocation.""" assert len(seq_group_metadata_list) == len(hidden_states) self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) @@ -972,7 +978,9 @@ def update(self, seq_group_metadata_list: List[SequenceGroupMetadata], # Adding dummy hidden_states to this to maintain same shape self.bonus_token_previous_hidden_states = torch.cat([ self.hidden_states, - hidden_states if bonus_token_previous_hidden_states is None else bonus_token_previous_hidden_states]) + hidden_states if bonus_token_previous_hidden_states is None else + bonus_token_previous_hidden_states + ]) def prune(self, seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: @@ -983,12 +991,14 @@ def prune(self, index = [self.seq_ids.index(seq_id) for seq_id in seq_ids] self.hidden_states = self.hidden_states[index] if self.bonus_token_previous_hidden_states is not None: - self.bonus_token_previous_hidden_states = self.bonus_token_previous_hidden_states[index] + self.bonus_token_previous_hidden_states = self\ + .bonus_token_previous_hidden_states[index] self.seq_ids = seq_ids - - def expand_with_bonus_tokens(self, - seq_with_bonus_token_in_last_step: set) -> None: - if self.bonus_token_previous_hidden_states is None or not seq_with_bonus_token_in_last_step: + + def expand_with_bonus_tokens( + self, seq_with_bonus_token_in_last_step: set) -> None: + if self.bonus_token_previous_hidden_states is None \ + or not seq_with_bonus_token_in_last_step: return index = [] @@ -997,9 +1007,11 @@ def expand_with_bonus_tokens(self, if seq_id in seq_with_bonus_token_in_last_step: index.append(i + len(self.seq_ids)) index.append(i) - - self.hidden_states = torch.cat([self.hidden_states, - self.bonus_token_previous_hidden_states])[index] + + self.hidden_states = torch.cat( + [self.hidden_states, + self.bonus_token_previous_hidden_states])[index] + @dataclass class ExecuteModelRequest: diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index e76e1bccf476f..dddbb579409b5 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -310,11 +310,19 @@ def execute_model( graph_batch_size = model_input.input_tokens.shape[0] model_executable = (self.graph_runners[model_input.virtual_engine] [graph_batch_size]) - hidden_states = torch.cat([previous_hidden_states, - torch.empty([graph_batch_size - previous_hidden_states.shape[0], - *previous_hidden_states.shape[1:]], - dtype=previous_hidden_states.dtype, - device=previous_hidden_states.device)]) + + if previous_hidden_states is not None: + hidden_states = torch.cat([ + previous_hidden_states, + torch.empty([ + graph_batch_size - previous_hidden_states.shape[0], + *previous_hidden_states.shape[1:] + ], + dtype=previous_hidden_states.dtype, + device=previous_hidden_states.device) + ]) + else: + hidden_states = None else: model_executable = self.model hidden_states = previous_hidden_states diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 1514264b76170..a239fc6f5bb5f 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -4,8 +4,8 @@ import torch -from vllm.sequence import (ExecuteModelRequest, HiddenStates, SamplerOutput, SequenceData, - SequenceGroupMetadata) +from vllm.sequence import (ExecuteModelRequest, HiddenStates, SamplerOutput, + SequenceData, SequenceGroupMetadata) from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) @@ -153,12 +153,12 @@ def _expand_execute_model_request( updated_execute_model_req.seq_group_metadata_list =\ updated_seq_group_metadata_list - + if isinstance(updated_execute_model_req.previous_hidden_states, HiddenStates): updated_execute_model_req.previous_hidden_states\ .expand_with_bonus_tokens(seq_with_bonus_token_in_last_step) - + return updated_execute_model_req, indices_of_original_sequence_groups @staticmethod diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 99604424555e4..244d2f708c8f5 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -145,7 +145,7 @@ def create_worker( "model_runner_cls"] = TP1DraftModelRunner else: if draft_worker_kwargs[ - "model_config"].hf_config.model_type == "eagle": + "model_config"].hf_config.model_type == "eagle": raise NotImplementedError( "EAGLE does not support TP > 1 yet") @@ -469,9 +469,10 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, else: self.previous_hidden_states.update( execute_model_req.seq_group_metadata_list, hidden_states) - + if not skip_proposer: - execute_model_req.previous_hidden_states = sampler_output.prefill_hidden_states + execute_model_req.previous_hidden_states = sampler_output\ + .prefill_hidden_states self.proposer_worker.execute_model(execute_model_req) sampler_output_to_return = (self._serialize_sampler_output_no_logprobs( @@ -537,7 +538,7 @@ def _run_speculative_decoding_step( #TODO: Fix it #5814 raise RuntimeError("Cannot handle cases where distributed draft " "workers generate no tokens") - + execute_model_req.previous_hidden_states = None proposal_scores = self.scorer.score_proposals( @@ -647,9 +648,9 @@ def _verify_tokens( bonus_token_previous_hidden_states = hidden_states[:, -2] # b x d hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d # Store hidden states from target model for subsequent decode step - self.previous_hidden_states = HiddenStates(seq_group_metadata_list, - hidden_states, - bonus_token_previous_hidden_states) + self.previous_hidden_states = HiddenStates( + seq_group_metadata_list, hidden_states, + bonus_token_previous_hidden_states) return accepted_token_ids, logprobs diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index d8d07b230aefb..4d72c35342728 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -6,9 +6,9 @@ from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, - JAISConfig, MedusaConfig, - MLPSpeculatorConfig, MPTConfig, - RWConfig, EAGLEConfig) + EAGLEConfig, JAISConfig, + MedusaConfig, MLPSpeculatorConfig, + MPTConfig, RWConfig) if VLLM_USE_MODELSCOPE: from modelscope import AutoConfig diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index c93f64e73efef..67f85c16f142f 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,12 +1,12 @@ from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.dbrx import DbrxConfig +from vllm.transformers_utils.configs.eagle import EAGLEConfig # RWConfig is for the original tiiuae/falcon-40b(-instruct) and # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.medusa import MedusaConfig -from vllm.transformers_utils.configs.eagle import EAGLEConfig from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.mpt import MPTConfig diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py index 9436b9b8185f5..9215ecd5228f9 100644 --- a/vllm/transformers_utils/configs/eagle.py +++ b/vllm/transformers_utils/configs/eagle.py @@ -12,24 +12,25 @@ def __init__(self, truncated_vocab_size: Optional[int] = None, **kwargs): - model_config = None if model is None else ( - AutoConfig.for_model(**model) if isinstance(model, dict) else model) - - for k,v in kwargs.items(): - if k != "architectures" and k != "model_type" and hasattr(model_config, k): + model_config = None if model is None else (AutoConfig.for_model( + **model) if isinstance(model, dict) else model) + + for k, v in kwargs.items(): + if k != "architectures" and k != "model_type" and hasattr( + model_config, k): setattr(model_config, k, v) - + self.model = model_config if self.model is None: self.truncated_vocab_size = None else: - self.truncated_vocab_size = self.vocab_size if truncated_vocab_size is None\ - else truncated_vocab_size + self.truncated_vocab_size = self.vocab_size if \ + truncated_vocab_size is None else truncated_vocab_size if "architectures" not in kwargs: kwargs["architectures"] = ["EAGLEModel"] - + super().__init__(**kwargs) @classmethod @@ -41,9 +42,9 @@ def from_pretrained( config_dict, kwargs = cls.get_config_dict( pretrained_model_name_or_path, **kwargs) return cls.from_dict(config_dict, **kwargs) - + def __getattribute__(self, key): try: return super().__getattribute__(key) - except: + except AttributeError: return getattr(self.model, key) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 5e8ef06d3c7e9..7c5591a147d54 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1037,12 +1037,14 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() previous_hidden_states = None - if "previous_hidden_states" in inspect.signature(self.execute_model).parameters: + if "previous_hidden_states" in inspect.signature( + self.execute_model).parameters: previous_hidden_states = torch.empty( - [max_batch_size, self.model_config.get_hidden_size()], + [max_batch_size, + self.model_config.get_hidden_size()], dtype=self.model_config.dtype, device=self.device) - + intermediate_inputs = None if not get_pp_group().is_first_rank: intermediate_inputs = self.model.make_empty_intermediate_tensors( @@ -1211,7 +1213,9 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: graph_capture_context.stream } if previous_hidden_states is not None: - capture_inputs["previous_hidden_states"] = previous_hidden_states[:batch_size] + capture_inputs[ + "previous_hidden_states"] = previous_hidden_states[: + batch_size] if self.has_seqlen_agnostic: # Only used by Mamba-based models CUDA graph atm (Jamba) @@ -1394,9 +1398,11 @@ def execute_model( if model_input.is_prompt: hidden_states = hidden_or_intermediate_states.index_select( 0, indices) - - prefill_hidden_states = hidden_or_intermediate_states.roll(shifts=1, dims=0) - prefill_hidden_states.masked_fill_((model_input.input_positions == 0).unsqueeze(-1), 0) + prefill_hidden_states = hidden_or_intermediate_states.roll( + shifts=1, dims=0) + assert isinstance(model_input.input_positions, torch.Tensor) + prefill_hidden_states.masked_fill_( + (model_input.input_positions == 0).unsqueeze(-1), 0) output.prefill_hidden_states = prefill_hidden_states elif decode_meta.use_cuda_graph: hidden_states = hidden_or_intermediate_states[:len(indices)] @@ -1545,11 +1551,11 @@ def forward( if "seqlen_agnostic_capture_inputs" in self.input_buffers: self.model.copy_inputs_before_cuda_graphs(self.input_buffers, **kwargs) - + if "previous_hidden_states" in self.input_buffers: self.input_buffers["previous_hidden_states"].copy_( kwargs["previous_hidden_states"], non_blocking=True) - + if intermediate_tensors is not None: for key in intermediate_tensors.tensors: self.input_buffers[key].copy_(intermediate_tensors[key], diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 43c7ad2237685..ca13263d5f7e0 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -10,8 +10,8 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.platforms import current_platform -from vllm.sequence import (ExecuteModelRequest, HiddenStates, IntermediateTensors, - SamplerOutput) +from vllm.sequence import (ExecuteModelRequest, HiddenStates, + IntermediateTensors, SamplerOutput) from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase @@ -239,10 +239,26 @@ def execute_model( execute_model_req.finished_requests_ids)) num_steps = execute_model_req.num_steps + if isinstance(execute_model_req.previous_hidden_states, + torch.Tensor): + kwargs = { + "previous_hidden_states": + execute_model_req.previous_hidden_states + } + elif isinstance(execute_model_req.previous_hidden_states, + HiddenStates): + kwargs = { + "previous_hidden_states": + execute_model_req.previous_hidden_states.hidden_states + } + else: + kwargs = {} + if self.do_metadata_broadcast: broadcast_data = worker_input.as_broadcastable_tensor_dict() broadcast_data.update( model_input.as_broadcastable_tensor_dict()) + broadcast_data.update(kwargs) broadcast_data["num_steps"] = num_steps broadcast_tensor_dict(broadcast_data, src=0) else: @@ -258,6 +274,8 @@ def execute_model( self.model_runner. make_model_input_from_broadcasted_tensor_dict(broadcast_data)) + kwargs = broadcast_data + self.execute_worker(worker_input) # If there is no input, we don't need to execute the model. @@ -269,16 +287,10 @@ def execute_model( intermediate_tensors = IntermediateTensors( get_pp_group().recv_tensor_dict()) - if isinstance(execute_model_req.previous_hidden_states, torch.Tensor): - kwargs = {"previous_hidden_states": execute_model_req.previous_hidden_states} - elif isinstance(execute_model_req.previous_hidden_states, HiddenStates): - kwargs = {"previous_hidden_states": execute_model_req.previous_hidden_states.hidden_states} - else: - kwargs = {} - output = self.model_runner.execute_model( model_input=model_input, - kv_caches=self.kv_cache[worker_input.virtual_engine] if self.kv_cache is not None else None, + kv_caches=self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, intermediate_tensors=intermediate_tensors, num_steps=num_steps, **kwargs, From b379948b87bd9a7cf9c8f7436f6f61d483ad57f8 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Fri, 26 Jul 2024 19:30:44 +0530 Subject: [PATCH 05/34] minor bug fix in graph capture --- vllm/worker/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 7c5591a147d54..1797bf0a34388 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1038,7 +1038,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: block_tables = torch.from_numpy(self.graph_block_tables).cuda() previous_hidden_states = None if "previous_hidden_states" in inspect.signature( - self.execute_model).parameters: + self.model.forward).parameters: previous_hidden_states = torch.empty( [max_batch_size, self.model_config.get_hidden_size()], From aef9c0003432f5ed7fcc61526036e93c9c84c8d7 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Fri, 26 Jul 2024 20:38:17 +0530 Subject: [PATCH 06/34] fixing broadcasting of hidden states in distributed worker --- vllm/worker/worker_base.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index ca13263d5f7e0..dd8cd6f45de2e 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -274,7 +274,12 @@ def execute_model( self.model_runner. make_model_input_from_broadcasted_tensor_dict(broadcast_data)) - kwargs = broadcast_data + if "previous_hidden_states" in broadcast_data: + kwargs = { + "previous_hidden_states": broadcast_data["previous_hidden_states"] + } + else: + kwargs = {} self.execute_worker(worker_input) From c8d63bdf9fddfa2925d24ce322e6b2075625e7a7 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Fri, 26 Jul 2024 20:39:15 +0530 Subject: [PATCH 07/34] formatting --- vllm/worker/worker_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index dd8cd6f45de2e..09c2f9ce42263 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -276,7 +276,8 @@ def execute_model( if "previous_hidden_states" in broadcast_data: kwargs = { - "previous_hidden_states": broadcast_data["previous_hidden_states"] + "previous_hidden_states": + broadcast_data["previous_hidden_states"] } else: kwargs = {} From 1a0aa60e6aa08696b7a929a13f13093d09f06436 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Sat, 27 Jul 2024 11:23:27 +0530 Subject: [PATCH 08/34] formatting --- vllm/transformers_utils/config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 25dd23af8710f..493aadc91dafb 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -8,7 +8,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, EAGLEConfig, JAISConfig, MedusaConfig, MLPSpeculatorConfig, - MPTConfig, NemotronConfig, RWConfig) + MPTConfig, NemotronConfig, + RWConfig) if VLLM_USE_MODELSCOPE: from modelscope import AutoConfig From b1f05acac7a31ff90c719c75b961522b0980979d Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Wed, 31 Jul 2024 12:03:07 +0530 Subject: [PATCH 09/34] Masking position=0 in inputs for EAGLE --- tests/spec_decode/e2e/test_eagle_correctness.py | 2 +- vllm/model_executor/models/eagle.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/spec_decode/e2e/test_eagle_correctness.py b/tests/spec_decode/e2e/test_eagle_correctness.py index f15e33becd3fb..5563e10ee1981 100644 --- a/tests/spec_decode/e2e/test_eagle_correctness.py +++ b/tests/spec_decode/e2e/test_eagle_correctness.py @@ -15,7 +15,7 @@ * Test greedy equality under preemption * Test greedy equality under various number of speculative tokens. -With those tests, we can say at least, Medusa would not break the +With those tests, we can say at least, EAGLE would not break the correctess for the target model outputs. """ diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index f0758b05e61d7..cf0a38ca65f27 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -48,6 +48,8 @@ def forward( inputs_embeds = self.fc( torch.cat([tok_embeds, previous_hidden_states], dim=-1)) + inputs_embeds[positions == 0] = 0 # masking inputs at position=0 + hidden_states = self.model.model( input_ids=None, inputs_embeds=inputs_embeds, From bdee07c42fe691d4f4e6d5c1890a387792849a60 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Wed, 31 Jul 2024 12:07:02 +0530 Subject: [PATCH 10/34] reformatting --- vllm/transformers_utils/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 6627c19dd6e43..ce7abf04e2230 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -7,8 +7,8 @@ from vllm.logger import init_logger from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, EAGLEConfig, InternVLChatConfig, - JAISConfig, MedusaConfig, - MLPSpeculatorConfig, MPTConfig, + JAISConfig, MedusaConfig, + MLPSpeculatorConfig, MPTConfig, NemotronConfig, RWConfig) if VLLM_USE_MODELSCOPE: From 441374f1f25d0bde6ce81910b53441f5e00cf401 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Wed, 31 Jul 2024 13:37:10 +0530 Subject: [PATCH 11/34] Fixing the order of execution for scorer and proposer in non-driver worker --- vllm/spec_decode/spec_decode_worker.py | 25 ++++++++++++++++-------- vllm/transformers_utils/configs/eagle.py | 13 ++++++------ 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 6d6d8f6f0a786..6c3448973471b 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -357,6 +357,9 @@ def execute_model( # communication to inform them. broadcast_dict = dict( num_lookahead_slots=num_lookahead_slots, + no_spec=num_lookahead_slots == 0 + or len(execute_model_req.seq_group_metadata_list) == 0 + or disable_all_speculation, disable_all_speculation=disable_all_speculation, ) broadcast_tensor_dict(broadcast_dict, src=self._driver_rank) @@ -504,15 +507,21 @@ def _run_non_driver_rank(self) -> bool: return False num_lookahead_slots = data["num_lookahead_slots"] - # Even if num_lookahead_slots is zero, we want to run the proposer model - # as it may have KV. - # - # We run the proposer once per lookahead slot. In the future we should - # delegate how many times it runs to the proposer. - for _ in range(max(num_lookahead_slots, 1)): - self.proposer_worker.execute_model() + if data["no_spec"]: + self.scorer_worker.execute_model() + + if not data["disable_all_speculation"]: + # Even if num_lookahead_slots is zero, we want to run the + # proposer model as it may have KV. + # + # We run the proposer once per lookahead slot. In the future we + # should delegate how many times it runs to the proposer. + for _ in range(max(num_lookahead_slots, 1)): + self.proposer_worker.execute_model() + + if not data["no_spec"]: + self.scorer_worker.execute_model() - self.scorer_worker.execute_model() return True @nvtx_range("spec_decode_worker._run_speculative_decoding_step") diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py index 9215ecd5228f9..b357a785e4dc4 100644 --- a/vllm/transformers_utils/configs/eagle.py +++ b/vllm/transformers_utils/configs/eagle.py @@ -25,7 +25,7 @@ def __init__(self, if self.model is None: self.truncated_vocab_size = None else: - self.truncated_vocab_size = self.vocab_size if \ + self.truncated_vocab_size = self.model.vocab_size if \ truncated_vocab_size is None else truncated_vocab_size if "architectures" not in kwargs: @@ -33,6 +33,11 @@ def __init__(self, super().__init__(**kwargs) + if self.model is not None: + for k, v in self.model.to_dict().items(): + if not hasattr(self, k): + setattr(self, k, v) + @classmethod def from_pretrained( cls, @@ -42,9 +47,3 @@ def from_pretrained( config_dict, kwargs = cls.get_config_dict( pretrained_model_name_or_path, **kwargs) return cls.from_dict(config_dict, **kwargs) - - def __getattribute__(self, key): - try: - return super().__getattribute__(key) - except AttributeError: - return getattr(self.model, key) From 0d1cbae36975bf134fbe236e73cf608e6add891e Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Thu, 1 Aug 2024 09:29:45 +0530 Subject: [PATCH 12/34] Adding hidden state propagation to _execute_model_spmd --- vllm/worker/worker_base.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 09c2f9ce42263..824a020ae699e 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -333,9 +333,27 @@ def _execute_model_spmd( if worker_input.num_seq_groups == 0: return [] + if isinstance(execute_model_req.previous_hidden_states, + torch.Tensor): + kwargs = { + "previous_hidden_states": + execute_model_req.previous_hidden_states + } + elif isinstance(execute_model_req.previous_hidden_states, + HiddenStates): + kwargs = { + "previous_hidden_states": + execute_model_req.previous_hidden_states.hidden_states + } + else: + kwargs = {} + return self.model_runner.execute_model( - model_input, self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None) + model_input=model_input, + kv_caches=self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, + **kwargs, + ) class WorkerWrapperBase: From b60384a2df603ca44e8b34124f8250c50a99a262 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Thu, 1 Aug 2024 09:31:46 +0530 Subject: [PATCH 13/34] Adding CUDA graph tests for medusa and eagle. Renaming mlp to medusa or eagle where required. --- .../spec_decode/e2e/test_eagle_correctness.py | 52 ++++++++++++++++--- .../e2e/test_medusa_correctness.py | 52 ++++++++++++++++--- 2 files changed, 92 insertions(+), 12 deletions(-) diff --git a/tests/spec_decode/e2e/test_eagle_correctness.py b/tests/spec_decode/e2e/test_eagle_correctness.py index 5563e10ee1981..d867d2ddfe80a 100644 --- a/tests/spec_decode/e2e/test_eagle_correctness.py +++ b/tests/spec_decode/e2e/test_eagle_correctness.py @@ -68,7 +68,47 @@ ]) @pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("seed", [1]) -def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, +def test_eagle_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify greedy equality with different batch size.""" + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "enforce_eager": False, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_eagle_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, batch_size: int, output_len: int): """Verify greedy equality with different batch size.""" run_greedy_equality_correctness_test(baseline_llm_generator, @@ -114,7 +154,7 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, ]) @pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("seed", [1]) -def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator, +def test_eagle_e2e_greedy_correctness_with_preemption(baseline_llm_generator, test_llm_generator, batch_size: int, output_len: int): @@ -163,9 +203,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator, 32, ]) @pytest.mark.parametrize("seed", [1]) -def test_mlp_different_k(baseline_llm_generator, test_llm_generator, +def test_eagle_different_k(baseline_llm_generator, test_llm_generator, batch_size: int, output_len: int): - """Verify that mlp speculative decoding produces exact equality + """Verify that eagle speculative decoding produces exact equality to without spec decode with different values of num_speculative_tokens. """ run_greedy_equality_correctness_test(baseline_llm_generator, @@ -206,9 +246,9 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator, 32, ]) @pytest.mark.parametrize("seed", [1]) -def test_mlp_disable_queue(baseline_llm_generator, test_llm_generator, +def test_eagle_disable_queue(baseline_llm_generator, test_llm_generator, batch_size: int, output_len: int): - """Verify that mlp speculative decoding produces exact equality + """Verify that eagle speculative decoding produces exact equality to without spec decode when speculation is disabled for large batch sizes. """ diff --git a/tests/spec_decode/e2e/test_medusa_correctness.py b/tests/spec_decode/e2e/test_medusa_correctness.py index 7e4a6cc62d02b..672d50b3ce1f2 100644 --- a/tests/spec_decode/e2e/test_medusa_correctness.py +++ b/tests/spec_decode/e2e/test_medusa_correctness.py @@ -70,7 +70,47 @@ ]) @pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("seed", [1]) -def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, +def test_medusa_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify greedy equality with different batch size.""" + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "enforce_eager": False, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_medusa_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, batch_size: int, output_len: int): """Verify greedy equality with different batch size.""" run_greedy_equality_correctness_test(baseline_llm_generator, @@ -116,7 +156,7 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, ]) @pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("seed", [1]) -def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator, +def test_medusa_e2e_greedy_correctness_with_preemption(baseline_llm_generator, test_llm_generator, batch_size: int, output_len: int): @@ -165,9 +205,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator, 32, ]) @pytest.mark.parametrize("seed", [1]) -def test_mlp_different_k(baseline_llm_generator, test_llm_generator, +def test_medusa_different_k(baseline_llm_generator, test_llm_generator, batch_size: int, output_len: int): - """Verify that mlp speculative decoding produces exact equality + """Verify that medusa speculative decoding produces exact equality to without spec decode with different values of num_speculative_tokens. """ run_greedy_equality_correctness_test(baseline_llm_generator, @@ -208,9 +248,9 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator, 32, ]) @pytest.mark.parametrize("seed", [1]) -def test_mlp_disable_queue(baseline_llm_generator, test_llm_generator, +def test_medusa_disable_queue(baseline_llm_generator, test_llm_generator, batch_size: int, output_len: int): - """Verify that mlp speculative decoding produces exact equality + """Verify that medusa speculative decoding produces exact equality to without spec decode when speculation is disabled for large batch sizes. """ From 7b6a0e6d180a019faf55e98afcda1d887c7d2499 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Thu, 1 Aug 2024 10:55:53 +0530 Subject: [PATCH 14/34] Moving hidden states shift to spec_decode_worker --- vllm/sequence.py | 14 +++++++------- vllm/spec_decode/spec_decode_worker.py | 14 ++++++++++++-- vllm/worker/model_runner.py | 7 +------ 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 524585d389e0e..dc330a31fbfd2 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -951,7 +951,7 @@ def __init__( assert len(seq_group_metadata_list) == len(hidden_states) self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list) self.hidden_states: torch.Tensor = hidden_states - self.bonus_token_previous_hidden_states: Optional[ + self.last_non_bonus_token_hidden_states: Optional[ torch.Tensor] = bonus_token_previous_hidden_states def update( @@ -964,7 +964,7 @@ def update( self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) self.hidden_states = torch.cat([self.hidden_states, hidden_states]) # Adding dummy hidden_states to this to maintain same shape - self.bonus_token_previous_hidden_states = torch.cat([ + self.last_non_bonus_token_hidden_states = torch.cat([ self.hidden_states, hidden_states if bonus_token_previous_hidden_states is None else bonus_token_previous_hidden_states @@ -978,14 +978,14 @@ def prune(self, # Batch contents changed - prune removed sequences. index = [self.seq_ids.index(seq_id) for seq_id in seq_ids] self.hidden_states = self.hidden_states[index] - if self.bonus_token_previous_hidden_states is not None: - self.bonus_token_previous_hidden_states = self\ - .bonus_token_previous_hidden_states[index] + if self.last_non_bonus_token_hidden_states is not None: + self.last_non_bonus_token_hidden_states = self\ + .last_non_bonus_token_hidden_states[index] self.seq_ids = seq_ids def expand_with_bonus_tokens( self, seq_with_bonus_token_in_last_step: set) -> None: - if self.bonus_token_previous_hidden_states is None \ + if self.last_non_bonus_token_hidden_states is None \ or not seq_with_bonus_token_in_last_step: return @@ -998,7 +998,7 @@ def expand_with_bonus_tokens( self.hidden_states = torch.cat( [self.hidden_states, - self.bonus_token_previous_hidden_states])[index] + self.last_non_bonus_token_hidden_states])[index] @dataclass diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 6c3448973471b..10fdea104384f 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -477,8 +477,16 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, execute_model_req.seq_group_metadata_list, hidden_states) if not skip_proposer: - execute_model_req.previous_hidden_states = sampler_output\ - .prefill_hidden_states + # For prefill step in proposer, we run the model for N-1 tokens + # because Nth token will be processed in the first decode step. For + # N-1 tokens, the input should be 0:N-1 hidden states which should + # be concatanated with 1:N token (since output of scorer has to be + # the input for proposer). Therefore, we shift the hidden states to + # align n-1th hidden state with nth token. + if sampler_output.prefill_hidden_states is not None: + execute_model_req.previous_hidden_states = sampler_output\ + .prefill_hidden_states.roll(shifts=1, dims=0) + self.proposer_worker.execute_model(execute_model_req) sampler_output_to_return = (self._serialize_sampler_output_no_logprobs( @@ -507,6 +515,8 @@ def _run_non_driver_rank(self) -> bool: return False num_lookahead_slots = data["num_lookahead_slots"] + # In case of prefill, scorer_worker has to be run before proposer so + # that the hidden states can be propagated to proposer when needed. if data["no_spec"]: self.scorer_worker.execute_model() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b67d194f9a2db..bfab64468d2e4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1402,12 +1402,7 @@ def execute_model( if model_input.is_prompt: hidden_states = hidden_or_intermediate_states.index_select( 0, indices) - prefill_hidden_states = hidden_or_intermediate_states.roll( - shifts=1, dims=0) - assert isinstance(model_input.input_positions, torch.Tensor) - prefill_hidden_states.masked_fill_( - (model_input.input_positions == 0).unsqueeze(-1), 0) - output.prefill_hidden_states = prefill_hidden_states + output.prefill_hidden_states = hidden_or_intermediate_states elif decode_meta.use_cuda_graph: hidden_states = hidden_or_intermediate_states[:len(indices)] else: From 9d806b33ebec6b1b5816d5d3fbb3051602ff3c8a Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Thu, 1 Aug 2024 11:00:11 +0530 Subject: [PATCH 15/34] formatting --- .../spec_decode/e2e/test_eagle_correctness.py | 21 +++++++++++-------- .../e2e/test_medusa_correctness.py | 21 +++++++++++-------- vllm/worker/worker_base.py | 3 +-- 3 files changed, 25 insertions(+), 20 deletions(-) diff --git a/tests/spec_decode/e2e/test_eagle_correctness.py b/tests/spec_decode/e2e/test_eagle_correctness.py index d867d2ddfe80a..97231e10ec8d7 100644 --- a/tests/spec_decode/e2e/test_eagle_correctness.py +++ b/tests/spec_decode/e2e/test_eagle_correctness.py @@ -68,8 +68,9 @@ ]) @pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("seed", [1]) -def test_eagle_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, - batch_size: int, output_len: int): +def test_eagle_e2e_greedy_correctness(baseline_llm_generator, + test_llm_generator, batch_size: int, + output_len: int): """Verify greedy equality with different batch size.""" run_greedy_equality_correctness_test(baseline_llm_generator, test_llm_generator, @@ -108,8 +109,10 @@ def test_eagle_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator ]) @pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("seed", [1]) -def test_eagle_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, - batch_size: int, output_len: int): +def test_eagle_e2e_greedy_correctness_cuda_graph(baseline_llm_generator, + test_llm_generator, + batch_size: int, + output_len: int): """Verify greedy equality with different batch size.""" run_greedy_equality_correctness_test(baseline_llm_generator, test_llm_generator, @@ -155,9 +158,9 @@ def test_eagle_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator @pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("seed", [1]) def test_eagle_e2e_greedy_correctness_with_preemption(baseline_llm_generator, - test_llm_generator, - batch_size: int, - output_len: int): + test_llm_generator, + batch_size: int, + output_len: int): """Verify greedy equality, even when some sequences are preempted mid- generation. """ @@ -204,7 +207,7 @@ def test_eagle_e2e_greedy_correctness_with_preemption(baseline_llm_generator, ]) @pytest.mark.parametrize("seed", [1]) def test_eagle_different_k(baseline_llm_generator, test_llm_generator, - batch_size: int, output_len: int): + batch_size: int, output_len: int): """Verify that eagle speculative decoding produces exact equality to without spec decode with different values of num_speculative_tokens. """ @@ -247,7 +250,7 @@ def test_eagle_different_k(baseline_llm_generator, test_llm_generator, ]) @pytest.mark.parametrize("seed", [1]) def test_eagle_disable_queue(baseline_llm_generator, test_llm_generator, - batch_size: int, output_len: int): + batch_size: int, output_len: int): """Verify that eagle speculative decoding produces exact equality to without spec decode when speculation is disabled for large batch sizes. diff --git a/tests/spec_decode/e2e/test_medusa_correctness.py b/tests/spec_decode/e2e/test_medusa_correctness.py index 672d50b3ce1f2..13a6c106ebd03 100644 --- a/tests/spec_decode/e2e/test_medusa_correctness.py +++ b/tests/spec_decode/e2e/test_medusa_correctness.py @@ -70,8 +70,9 @@ ]) @pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("seed", [1]) -def test_medusa_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, - batch_size: int, output_len: int): +def test_medusa_e2e_greedy_correctness(baseline_llm_generator, + test_llm_generator, batch_size: int, + output_len: int): """Verify greedy equality with different batch size.""" run_greedy_equality_correctness_test(baseline_llm_generator, test_llm_generator, @@ -110,8 +111,10 @@ def test_medusa_e2e_greedy_correctness(baseline_llm_generator, test_llm_generato ]) @pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("seed", [1]) -def test_medusa_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, - batch_size: int, output_len: int): +def test_medusa_e2e_greedy_correctness_cuda_graph(baseline_llm_generator, + test_llm_generator, + batch_size: int, + output_len: int): """Verify greedy equality with different batch size.""" run_greedy_equality_correctness_test(baseline_llm_generator, test_llm_generator, @@ -157,9 +160,9 @@ def test_medusa_e2e_greedy_correctness(baseline_llm_generator, test_llm_generato @pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("seed", [1]) def test_medusa_e2e_greedy_correctness_with_preemption(baseline_llm_generator, - test_llm_generator, - batch_size: int, - output_len: int): + test_llm_generator, + batch_size: int, + output_len: int): """Verify greedy equality, even when some sequences are preempted mid- generation. """ @@ -206,7 +209,7 @@ def test_medusa_e2e_greedy_correctness_with_preemption(baseline_llm_generator, ]) @pytest.mark.parametrize("seed", [1]) def test_medusa_different_k(baseline_llm_generator, test_llm_generator, - batch_size: int, output_len: int): + batch_size: int, output_len: int): """Verify that medusa speculative decoding produces exact equality to without spec decode with different values of num_speculative_tokens. """ @@ -249,7 +252,7 @@ def test_medusa_different_k(baseline_llm_generator, test_llm_generator, ]) @pytest.mark.parametrize("seed", [1]) def test_medusa_disable_queue(baseline_llm_generator, test_llm_generator, - batch_size: int, output_len: int): + batch_size: int, output_len: int): """Verify that medusa speculative decoding produces exact equality to without spec decode when speculation is disabled for large batch sizes. diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 824a020ae699e..cfc0a8051d98b 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -333,8 +333,7 @@ def _execute_model_spmd( if worker_input.num_seq_groups == 0: return [] - if isinstance(execute_model_req.previous_hidden_states, - torch.Tensor): + if isinstance(execute_model_req.previous_hidden_states, torch.Tensor): kwargs = { "previous_hidden_states": execute_model_req.previous_hidden_states From 8db174fae40a75fc7d2e74dbf7f8cbe97db4b600 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Fri, 2 Aug 2024 11:29:54 +0530 Subject: [PATCH 16/34] Adding vocab truncation to EAGLE --- vllm/model_executor/models/eagle.py | 66 +++++++++++++++++++++++--- vllm/model_executor/models/medusa.py | 5 ++ vllm/sequence.py | 10 ++-- vllm/spec_decode/spec_decode_worker.py | 4 +- 4 files changed, 72 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index cf0a38ca65f27..2737c54d782b8 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -4,6 +4,9 @@ import torch.nn as nn from vllm.attention.backends.abstract import AttentionMetadata +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models import ModelRegistry from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -28,6 +31,27 @@ def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None: config.model.hidden_size, bias=False) + self.orig_vocab_size = config.vocab_size + self.truncated_vocab_size = config.truncated_vocab_size + self.unpadded_vocab_size = self.truncated_vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=self.truncated_vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.truncated_vocab_size, + logit_scale) + + # Token map is a idx to token mapping to reduce the vocab size for + # the draft model. Using smaller vocab size for draft, containing + # only most frequent tokens reduces the speculation overhead. This + # doesn't affect the acceptance rate much and thus gives more speed + # -up. self.token_map = None @property @@ -61,7 +85,19 @@ def forward( def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - return self.model.compute_logits(hidden_states, sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + if self.token_map is not None: + _logits = logits + logits = -torch.inf * torch.ones( + size=(*_logits.shape[:-1], self.orig_vocab_size), + device=_logits.device, + dtype=_logits.dtype) + + logits[..., self.token_map] = _logits + + return logits def sample( self, @@ -72,15 +108,33 @@ def sample( return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - model_weights = [] + model_weights = {} for name, loaded_weight in weights: - if name.startswith("fc."): + if name == "token_map": + if self.config.truncated_vocab_size < self.config.vocab_size: + self.token_map = nn.Parameter(loaded_weight, + requires_grad=False) + elif name.startswith("fc."): weight_loader = getattr(self.fc.weight, "weight_loader", default_weight_loader) weight_loader(self.fc.weight, loaded_weight) + elif name.startswith("model.lm_head.") or name.startswith( + "model.model."): + model_weights[name.split("model.", 1)[-1]] = loaded_weight elif name.startswith("lm_head.") or name.startswith("model."): - model_weights.append((name, loaded_weight)) + model_weights[name] = loaded_weight else: - model_weights.append((f"model.{name}", loaded_weight)) + model_weights[f"model.{name}"] = loaded_weight + + lm_head_weight = model_weights.pop("lm_head.weight") + + if self.token_map is not None and\ + lm_head_weight.shape[0] > self.token_map.shape[0]: + + lm_head_weight = lm_head_weight[self.token_map] + + weight_loader = getattr(self.lm_head.weight, "weight_loader", + default_weight_loader) + weight_loader(self.lm_head.weight, lm_head_weight) - self.model.load_weights(model_weights) + self.model.load_weights(model_weights.items()) diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index 6453d0cb25c91..4a908e9cd74d5 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -57,6 +57,11 @@ def __init__(self, config: MedusaConfig, **_) -> None: self.truncated_vocab_size, logit_scale) + # Token map is a idx to token mapping to reduce the vocab size for + # the draft model. Using smaller vocab size for draft, containing + # only most frequent tokens reduces the speculation overhead. This + # doesn't affect the acceptance rate much and thus gives more speed + # -up. self.token_map = None def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]: diff --git a/vllm/sequence.py b/vllm/sequence.py index dc330a31fbfd2..80a6bc110cf4b 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -947,18 +947,18 @@ def __init__( self, seq_group_metadata_list: List[SequenceGroupMetadata], hidden_states: torch.Tensor, - bonus_token_previous_hidden_states: Optional[torch.Tensor] = None): + last_non_bonus_token_hidden_states: Optional[torch.Tensor] = None): assert len(seq_group_metadata_list) == len(hidden_states) self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list) self.hidden_states: torch.Tensor = hidden_states self.last_non_bonus_token_hidden_states: Optional[ - torch.Tensor] = bonus_token_previous_hidden_states + torch.Tensor] = last_non_bonus_token_hidden_states def update( self, seq_group_metadata_list: List[SequenceGroupMetadata], hidden_states: torch.Tensor, - bonus_token_previous_hidden_states: Optional[torch.Tensor] = None): + last_non_bonus_token_hidden_states: Optional[torch.Tensor] = None): """Update hidden states from target model invocation.""" assert len(seq_group_metadata_list) == len(hidden_states) self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) @@ -966,8 +966,8 @@ def update( # Adding dummy hidden_states to this to maintain same shape self.last_non_bonus_token_hidden_states = torch.cat([ self.hidden_states, - hidden_states if bonus_token_previous_hidden_states is None else - bonus_token_previous_hidden_states + hidden_states if last_non_bonus_token_hidden_states is None else + last_non_bonus_token_hidden_states ]) def prune(self, diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 10fdea104384f..29eb84b46fe19 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -661,12 +661,12 @@ def _verify_tokens( accepted_index = accepted_token_ids + 1 # Convert -1 to 0 accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) index = accepted_index[:, None, None].expand(-1, 1, hs_size) - bonus_token_previous_hidden_states = hidden_states[:, -2] # b x d + last_non_bonus_token_hidden_states = hidden_states[:, -2] # b x d hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d # Store hidden states from target model for subsequent decode step self.previous_hidden_states = HiddenStates( seq_group_metadata_list, hidden_states, - bonus_token_previous_hidden_states) + last_non_bonus_token_hidden_states) return accepted_token_ids, logprobs From b6b05482c4e69c01986aeb3f8922e6dd12099c86 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Fri, 2 Aug 2024 13:43:06 +0530 Subject: [PATCH 17/34] Minor changes and fixes. Adding expand model request and hidden states sync test --- .../spec_decode/e2e/test_eagle_correctness.py | 3 +- tests/spec_decode/test_multi_step_worker.py | 55 ++++++++++++++++++- vllm/sequence.py | 15 +++-- vllm/worker/worker_base.py | 54 +++++++----------- 4 files changed, 85 insertions(+), 42 deletions(-) diff --git a/tests/spec_decode/e2e/test_eagle_correctness.py b/tests/spec_decode/e2e/test_eagle_correctness.py index 97231e10ec8d7..6a1819e990f44 100644 --- a/tests/spec_decode/e2e/test_eagle_correctness.py +++ b/tests/spec_decode/e2e/test_eagle_correctness.py @@ -113,7 +113,8 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(baseline_llm_generator, test_llm_generator, batch_size: int, output_len: int): - """Verify greedy equality with different batch size.""" + """Verify greedy equality with cuda graph enabled and different + batch sizes.""" run_greedy_equality_correctness_test(baseline_llm_generator, test_llm_generator, batch_size, diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index 442e40f07f0bb..40808c06962d1 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -6,7 +6,8 @@ import torch from vllm.model_executor.utils import set_random_seed -from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput +from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob, + SamplerOutput, get_all_seq_ids) from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.top1_proposer import Top1Proposer @@ -690,3 +691,55 @@ def test_use_draft_model_runner_advance_step(): worker.execute_model(execute_model_req=execute_model_req) call_args_list = worker.model_runner._gpu_advance_step.call_args_list assert len(call_args_list) == 1 + + +@torch.inference_mode() +def test_expand_execute_model_request_sync_with_expand_hidden_states(): + """ + In this test we verify that the logic for expanding the + seq_group_metadata_list remains in sync with the expansion logic of + the HiddenStates in _expand_execute_model_request. + """ + k = 5 + batch_size = 16 + # seq_groups_to_merge = {3: 2, 6: 5, 7: 5, 13: 12, 14: 12, 15: 12} + seq_with_bonus_token_in_last_step = [1, 3, 8, 10, 13, 15] + + seq_group_metadata_list, _, _ = create_batch(batch_size, k) + # merged_seq_group_metadata = {} + # for i, seq_group_metadata in enumerate(seq_group_metadata_list): + # if i in seq_groups_to_merge: + # dst = merged_seq_group_metadata.get( + # seq_groups_to_merge[i], + # seq_group_metadata_list[seq_groups_to_merge[i]]) + # src = seq_group_metadata + # prompt_token_ids = next(iter(dst.seq_data.values()))\ + # .prompt_token_ids + # for seq_id, seq in src.seq_data.items(): + # dst.seq_data[seq_id] = SequenceData(prompt_token_ids, + # seq.output_token_ids) + # dst.block_tables[seq_id] = src.block_tables[seq_id] + # merged_seq_group_metadata[seq_groups_to_merge[i]] = dst + # else: + # merged_seq_group_metadata[i] = seq_group_metadata + # seq_group_metadata_list = [merged_seq_group_metadata[k] for k in sorted( + # merged_seq_group_metadata.keys())] + + execute_model_request = ExecuteModelRequest( + seq_group_metadata_list, + previous_hidden_states=HiddenStates( + seq_group_metadata_list, torch.arange(batch_size), + torch.arange(batch_size, 2 * batch_size))) + + expanded_execute_model_request, orig_seq_group_ids = MultiStepWorker.\ + _expand_execute_model_request(execute_model_request, + seq_with_bonus_token_in_last_step) + + all_seq_ids = torch.tensor( + get_all_seq_ids( + expanded_execute_model_request.seq_group_metadata_list)) + ref_expanded_hidden_states = all_seq_ids + batch_size + ref_expanded_hidden_states[orig_seq_group_ids] -= batch_size + + assert (ref_expanded_hidden_states == expanded_execute_model_request. + previous_hidden_states.hidden_states).all().item() diff --git a/vllm/sequence.py b/vllm/sequence.py index 80a6bc110cf4b..3142c31221ddc 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -963,12 +963,15 @@ def update( assert len(seq_group_metadata_list) == len(hidden_states) self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) self.hidden_states = torch.cat([self.hidden_states, hidden_states]) - # Adding dummy hidden_states to this to maintain same shape - self.last_non_bonus_token_hidden_states = torch.cat([ - self.hidden_states, - hidden_states if last_non_bonus_token_hidden_states is None else - last_non_bonus_token_hidden_states - ]) + + if self.last_non_bonus_token_hidden_states is not None: + # Adding dummy hidden_states to this to maintain same shape + self.last_non_bonus_token_hidden_states = torch.cat([ + self.last_non_bonus_token_hidden_states, + torch.zeros_like(hidden_states) + if last_non_bonus_token_hidden_states is None else + last_non_bonus_token_hidden_states + ]) def prune(self, seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 3ceb45b400940..3fed9551913e6 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -239,20 +239,7 @@ def execute_model( execute_model_req.finished_requests_ids)) num_steps = execute_model_req.num_steps - if isinstance(execute_model_req.previous_hidden_states, - torch.Tensor): - kwargs = { - "previous_hidden_states": - execute_model_req.previous_hidden_states - } - elif isinstance(execute_model_req.previous_hidden_states, - HiddenStates): - kwargs = { - "previous_hidden_states": - execute_model_req.previous_hidden_states.hidden_states - } - else: - kwargs = {} + kwargs = extract_previous_hidden_states(execute_model_req) if self.do_metadata_broadcast: broadcast_data = worker_input.as_broadcastable_tensor_dict() @@ -274,13 +261,7 @@ def execute_model( self.model_runner. make_model_input_from_broadcasted_tensor_dict(broadcast_data)) - if "previous_hidden_states" in broadcast_data: - kwargs = { - "previous_hidden_states": - broadcast_data["previous_hidden_states"] - } - else: - kwargs = {} + kwargs = extract_previous_hidden_states(broadcast_data) self.execute_worker(worker_input) @@ -335,19 +316,7 @@ def _execute_model_spmd( if worker_input.num_seq_groups == 0: return [] - if isinstance(execute_model_req.previous_hidden_states, torch.Tensor): - kwargs = { - "previous_hidden_states": - execute_model_req.previous_hidden_states - } - elif isinstance(execute_model_req.previous_hidden_states, - HiddenStates): - kwargs = { - "previous_hidden_states": - execute_model_req.previous_hidden_states.hidden_states - } - else: - kwargs = {} + kwargs = extract_previous_hidden_states(execute_model_req) return self.model_runner.execute_model( model_input=model_input, @@ -428,3 +397,20 @@ def execute_method(self, method, *args, **kwargs): "This might cause deadlock in distributed execution.") logger.exception(msg) raise e + + +def extract_previous_hidden_states( + data: Union[ExecuteModelRequest, dict]) -> dict: + output = {} + + if isinstance(data, dict): + if "previous_hidden_states" in data: + output["previous_hidden_states"] = data["previous_hidden_states"] + else: + if isinstance(data.previous_hidden_states, torch.Tensor): + output["previous_hidden_states"] = data.previous_hidden_states + elif isinstance(data.previous_hidden_states, HiddenStates): + output["previous_hidden_states"] = data.previous_hidden_states\ + .hidden_states + + return output From eaa586c269122723b9e8d6e30e8fc23c313328bf Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Tue, 6 Aug 2024 07:06:25 +0000 Subject: [PATCH 18/34] Removing commented code and a minor comment fix --- .../e2e/test_medusa_correctness.py | 3 ++- tests/spec_decode/test_multi_step_worker.py | 21 +------------------ 2 files changed, 3 insertions(+), 21 deletions(-) diff --git a/tests/spec_decode/e2e/test_medusa_correctness.py b/tests/spec_decode/e2e/test_medusa_correctness.py index 13a6c106ebd03..de4b2ab796a3c 100644 --- a/tests/spec_decode/e2e/test_medusa_correctness.py +++ b/tests/spec_decode/e2e/test_medusa_correctness.py @@ -115,7 +115,8 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(baseline_llm_generator, test_llm_generator, batch_size: int, output_len: int): - """Verify greedy equality with different batch size.""" + """Verify greedy equality with cuda graph enabled and different + batch sizes.""" run_greedy_equality_correctness_test(baseline_llm_generator, test_llm_generator, batch_size, diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index 40808c06962d1..4d61dc38af4da 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -702,29 +702,10 @@ def test_expand_execute_model_request_sync_with_expand_hidden_states(): """ k = 5 batch_size = 16 - # seq_groups_to_merge = {3: 2, 6: 5, 7: 5, 13: 12, 14: 12, 15: 12} seq_with_bonus_token_in_last_step = [1, 3, 8, 10, 13, 15] seq_group_metadata_list, _, _ = create_batch(batch_size, k) - # merged_seq_group_metadata = {} - # for i, seq_group_metadata in enumerate(seq_group_metadata_list): - # if i in seq_groups_to_merge: - # dst = merged_seq_group_metadata.get( - # seq_groups_to_merge[i], - # seq_group_metadata_list[seq_groups_to_merge[i]]) - # src = seq_group_metadata - # prompt_token_ids = next(iter(dst.seq_data.values()))\ - # .prompt_token_ids - # for seq_id, seq in src.seq_data.items(): - # dst.seq_data[seq_id] = SequenceData(prompt_token_ids, - # seq.output_token_ids) - # dst.block_tables[seq_id] = src.block_tables[seq_id] - # merged_seq_group_metadata[seq_groups_to_merge[i]] = dst - # else: - # merged_seq_group_metadata[i] = seq_group_metadata - # seq_group_metadata_list = [merged_seq_group_metadata[k] for k in sorted( - # merged_seq_group_metadata.keys())] - + execute_model_request = ExecuteModelRequest( seq_group_metadata_list, previous_hidden_states=HiddenStates( From 38e2b5c975c8fd416a315be8c1dfa41102db0c40 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Tue, 6 Aug 2024 07:25:53 +0000 Subject: [PATCH 19/34] formatting --- tests/spec_decode/test_multi_step_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index 4d61dc38af4da..96ea5eaa6861c 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -705,7 +705,7 @@ def test_expand_execute_model_request_sync_with_expand_hidden_states(): seq_with_bonus_token_in_last_step = [1, 3, 8, 10, 13, 15] seq_group_metadata_list, _, _ = create_batch(batch_size, k) - + execute_model_request = ExecuteModelRequest( seq_group_metadata_list, previous_hidden_states=HiddenStates( From c5f8d15694211879f6d8313824b667a1837448b7 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Wed, 7 Aug 2024 14:43:24 +0530 Subject: [PATCH 20/34] adding comments to clarify compatibility of eagle checkpoint in eagle.py --- vllm/model_executor/models/eagle.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index 2737c54d782b8..88e6c4d443306 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -108,6 +108,12 @@ def sample( return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # This implementation is incompitable with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B + # due to missing lm_head weights and its config being that of a + # Llama model. Here's a compatible version with the same weights: + # https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm + # Also, here's an example script for converting trained EAGLE + # checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d model_weights = {} for name, loaded_weight in weights: if name == "token_map": From 7f46c68da317b891442270d0ed0bc9a6408c9432 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Wed, 7 Aug 2024 18:28:17 +0530 Subject: [PATCH 21/34] fixing model_cls resolution in eagle --- vllm/model_executor/models/eagle.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index 88e6c4d443306..fee37bb278ba0 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -7,8 +7,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) +from vllm.model_executor.model_loader.utils import get_model_architecture from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models import ModelRegistry from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.transformers_utils.configs.eagle import EAGLEConfig @@ -20,11 +20,7 @@ def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None: super().__init__() self.config = config - architectures = getattr(self.config.model, "architectures", []) - for arch in architectures: - model_cls = ModelRegistry.load_model_cls(arch) - if model_cls is not None: - break + model_cls, _ = get_model_architecture(self.config.model) self.model = model_cls(self.config.model, *args, **kwargs) self.fc = nn.Linear(config.model.hidden_size * 2, From 5e5d214b174a00b2d635babd4c32644370fe3ecc Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Wed, 7 Aug 2024 19:57:00 +0530 Subject: [PATCH 22/34] fixing model_cls resolution in eagle --- vllm/model_executor/models/eagle.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index fee37bb278ba0..1b12a17499365 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -7,8 +7,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) -from vllm.model_executor.model_loader.utils import get_model_architecture from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models import ModelRegistry from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.transformers_utils.configs.eagle import EAGLEConfig @@ -20,7 +20,8 @@ def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None: super().__init__() self.config = config - model_cls, _ = get_model_architecture(self.config.model) + architectures = getattr(self.config.model, "architectures", []) + model_cls, _ = ModelRegistry.resolve_model_cls(architectures) self.model = model_cls(self.config.model, *args, **kwargs) self.fc = nn.Linear(config.model.hidden_size * 2, From ad04e7f7e288481274685c0bf3370f65f0f5cc9a Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Tue, 13 Aug 2024 12:22:34 +0530 Subject: [PATCH 23/34] adding doctrings to EAGLE and Medusa models --- vllm/model_executor/models/eagle.py | 17 ++++++++++++++++- vllm/model_executor/models/medusa.py | 13 ++++++++++++- vllm/worker/model_runner.py | 3 +++ vllm/worker/worker_base.py | 3 +++ 4 files changed, 34 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index 1b12a17499365..0c882bf6b643c 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -15,6 +15,20 @@ class EAGLE(nn.Module): + """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077 + Reference implementation: https://github.com/SafeAILab/EAGLE + + Differences from reference implementation: + 1. In reference, LlamaDecoderLayer implementation doesn't have + input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427) + but we do as HF implementation also does. + 2. We allow any decoder layer to be used in EAGLE whereas in reference + decoder layer is fixed to be LlamaDecoderLayer. + 3. We have an optional token_map which reduces draft vocab to most + frequently used tokens to give some additional speed-up by reducing + sampling overhead. This is disabled unless the checkpoint file has + explicit token_map tensor and config has an optional attribute + truncated_vocab_size < vocab_size.""" def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None: super().__init__() @@ -48,7 +62,8 @@ def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None: # the draft model. Using smaller vocab size for draft, containing # only most frequent tokens reduces the speculation overhead. This # doesn't affect the acceptance rate much and thus gives more speed - # -up. + # -up. By default, this is disabled and is only used if the EAGLE + # checkpoint file has token_map tensor. self.token_map = None @property diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index 4a908e9cd74d5..dc26bad694fec 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -30,6 +30,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Medusa(nn.Module): + """This class implements the Medusa draft model from the paper: https://arxiv.org/abs/2401.10774 + Reference implementation: https://github.com/FasterDecoding/Medusa + + Differences from reference implementation: + 1. Currently this only supports generating proposals from top-1 tokens. + 2. We have an optional token_map which reduces draft vocab to most + frequently used tokens to give some additional speed-up by reducing + sampling overhead. This is disabled unless the checkpoint file has + explicit token_map tensor and config has an optional attribute + truncated_vocab_size < vocab_size.""" def __init__(self, config: MedusaConfig, **_) -> None: super().__init__() @@ -61,7 +71,8 @@ def __init__(self, config: MedusaConfig, **_) -> None: # the draft model. Using smaller vocab size for draft, containing # only most frequent tokens reduces the speculation overhead. This # doesn't affect the acceptance rate much and thus gives more speed - # -up. + # -up. By default, this is disabled and is only used if the EAGLE + # checkpoint file has token_map tensor. self.token_map = None def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 5e51a5eb3833b..4bb780cfe8a75 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1206,6 +1206,9 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: slot_mapping.fill_(_PAD_SLOT_ID) seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() + + # Prepare dummy previous_hidden_states only if needed by the model. + # This is used by draft models such as EAGLE. previous_hidden_states = None if "previous_hidden_states" in inspect.signature( self.model.forward).parameters: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index ac801c112a25d..5b7adbe04d906 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -404,6 +404,9 @@ def execute_method(self, method, *args, **kwargs): def extract_previous_hidden_states( data: Union[ExecuteModelRequest, dict]) -> dict: + """If data contains previous_hidden_states, extract it. This returns a dict + which can be used directly as additional kwargs in any following + execute_model calls.""" output = {} if isinstance(data, dict): From 90bee1d9bf36ce1538fc939b4e5d6e907ad21331 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Wed, 14 Aug 2024 12:55:58 +0530 Subject: [PATCH 24/34] fixing hidden states handling in batch expansion --- vllm/spec_decode/batch_expansion.py | 86 ++++++++++++++++++-------- vllm/spec_decode/spec_decode_worker.py | 5 +- vllm/spec_decode/top1_proposer.py | 2 +- vllm/spec_decode/util.py | 20 +++++- 4 files changed, 81 insertions(+), 32 deletions(-) diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 45eaeb51c5c0f..aa973391a3d93 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -1,5 +1,5 @@ from itertools import chain, count -from typing import Iterator, List, Tuple +from typing import Iterator, List, Optional, Tuple import torch @@ -86,21 +86,22 @@ def score_proposals( assert len(target_sampler_output) == 1, "expected single-step output" target_sampler_output = target_sampler_output[0] - all_tokens, all_probs, spec_logprobs = self._contract_batch( - contracted_bs=len(execute_model_req.seq_group_metadata_list), - target_sampler_output=target_sampler_output, - proposals=proposals, - num_scoring_tokens=num_scoring_tokens, - non_spec_indices=non_spec_indices, - spec_indices=spec_indices, - k=execute_model_req.num_lookahead_slots, - ) + (all_tokens, all_probs, spec_logprobs, + all_hidden_states) = self._contract_batch( + contracted_bs=len(execute_model_req.seq_group_metadata_list), + target_sampler_output=target_sampler_output, + proposals=proposals, + num_scoring_tokens=num_scoring_tokens, + non_spec_indices=non_spec_indices, + spec_indices=spec_indices, + k=execute_model_req.num_lookahead_slots, + ) return SpeculativeScores( probs=all_probs, token_ids=all_tokens, logprobs=spec_logprobs, - hidden_states=target_sampler_output.hidden_states, + hidden_states=all_hidden_states, ) def _expand_batch( @@ -143,10 +144,11 @@ def _expand_batch( num_scoring_tokens) def _contract_batch( - self, contracted_bs: int, target_sampler_output: SamplerOutput, - proposals: SpeculativeProposals, num_scoring_tokens: int, - non_spec_indices: List[int], spec_indices: List[int], - k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + self, contracted_bs: int, target_sampler_output: SamplerOutput, + proposals: SpeculativeProposals, num_scoring_tokens: int, + non_spec_indices: List[int], spec_indices: List[int], k: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor]]: """Contract the expanded batch back into its original size. This maps the scores of speculative tokens back to their original sequences. @@ -154,9 +156,10 @@ def _contract_batch( contracted_bs is the original batch size, and the batch size that the target_sampler_output will be contracted to. """ - (target_token_ids, target_probs, target_logprobs, + (target_token_ids, target_probs, target_logprobs, target_hidden_states, non_spec_target_token_ids, non_spec_target_probs, - non_spec_target_logprobs) = self._split_scoring_output( + non_spec_target_logprobs, + non_spec_target_hidden_states) = self._split_scoring_output( target_sampler_output, num_scoring_tokens) # Map distinct sequences used to score each token @@ -174,23 +177,40 @@ def _contract_batch( self._vocab_size) target_logprobs = target_logprobs.reshape(target_probs.shape) + if target_hidden_states is not None: + target_hidden_states = target_hidden_states.reshape( + spec_expanded_bs, k + 1, target_hidden_states.shape[-1]) + all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1), fill_value=-1) all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size) all_logprobs = target_logprobs.new_full(size=all_probs.shape, fill_value=-float("inf")) + if target_sampler_output.hidden_states is not None: + all_hidden_states = target_hidden_states.new_zeros( + size=(contracted_bs, k + 1, target_hidden_states.shape[-1])) + else: + all_hidden_states = None + if non_spec_indices: all_tokens[non_spec_indices, :1] = non_spec_target_token_ids all_probs[non_spec_indices, :1, :] = non_spec_target_probs all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs + if all_hidden_states is not None: + all_hidden_states[ + non_spec_indices, :1, :] = non_spec_target_hidden_states + if spec_indices: all_tokens[spec_indices] = target_token_ids all_probs[spec_indices] = target_probs all_logprobs[spec_indices] = target_logprobs - return all_tokens, all_probs, all_logprobs + if all_hidden_states is not None: + all_hidden_states[spec_indices] = target_hidden_states + + return all_tokens, all_probs, all_logprobs, all_hidden_states def _create_scoring_model_input( self, @@ -324,8 +344,9 @@ def _create_single_target_seq_group_metadata( def _split_scoring_output( self, sampler_output: SamplerOutput, num_scoring_tokens: int - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], torch.Tensor, torch.Tensor, + torch.Tensor, Optional[torch.Tensor]]: """Split the target model output into speculative and non-speculative output. """ @@ -350,24 +371,37 @@ def _split_scoring_output( non_spec_logprobs, ) = sampler_output.logprobs.split(split_sizes) + if sampler_output.hidden_states is not None: + ( + spec_hidden_states, + non_spec_hidden_states, + ) = sampler_output.hidden_states.split(split_sizes) + else: + spec_hidden_states, non_spec_hidden_states = None, None + # Convert scores to tensors. sampler_output.sampled_token_probs = spec_probs sampler_output.sampled_token_ids = spec_sampled_tokens sampler_output.logprobs = spec_logprobs - (target_token_ids, target_probs, - target_logprobs) = sampler_output_to_torch([sampler_output], True) + sampler_output.hidden_states = spec_hidden_states + (target_token_ids, target_probs, target_logprobs, + target_hidden_states) = sampler_output_to_torch([sampler_output], + True) # Convert non-speculative output tokens to tensors. sampler_output.sampled_token_probs = non_spec_probs sampler_output.sampled_token_ids = non_spec_sampled_tokens sampler_output.logprobs = non_spec_logprobs + sampler_output.hidden_states = non_spec_hidden_states (non_spec_target_token_ids, non_spec_target_probs, - non_spec_target_logprobs) = sampler_output_to_torch([sampler_output], - True) + non_spec_target_logprobs, + non_spec_target_hidden_states) = sampler_output_to_torch( + [sampler_output], True) return (target_token_ids, target_probs, target_logprobs, - non_spec_target_token_ids, non_spec_target_probs, - non_spec_target_logprobs) + target_hidden_states, non_spec_target_token_ids, + non_spec_target_probs, non_spec_target_logprobs, + non_spec_target_hidden_states) def _create_target_seq_id_iterator( self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 63a00139cc09d..acf77a7349eef 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -646,9 +646,8 @@ def _verify_tokens( hidden_states = proposal_scores.hidden_states if hidden_states is not None: # Contract hidden states based on accepted tokens - hs_size = hidden_states.shape[1] - hidden_states = hidden_states.reshape(-1, max_proposal_len + 1, - hs_size) + hs_size = hidden_states.shape[-1] + accepted_index = accepted_token_ids + 1 # Convert -1 to 0 accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) index = accepted_index[:, None, None].expand(-1, 1, hs_size) diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index 1a56497030280..28f7f7eb069ab 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -242,7 +242,7 @@ def _merge_outputs( return proposal_tokens, proposal_probs, proposal_lens_tensor sampler_output = maybe_sampler_output - proposal_tokens, proposal_probs, _ = sampler_output_to_torch( + proposal_tokens, proposal_probs, *_ = sampler_output_to_torch( sampler_output, sampler_transposed) # Now, reformat the output GPU tensors such that each sequence has diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index c6223a97dba10..b85f2a6f70ac0 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -123,7 +123,7 @@ def split_batch_by_proposal_len( def sampler_output_to_torch( sampler_output_list: List[SamplerOutput], sampler_transposed: bool -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """Utility function which converts a list of SamplerOutput to tensors. sampler_transposed here is used as the indicator for whether @@ -169,7 +169,23 @@ def sampler_output_to_torch( if sampler_transposed: sampled_token_ids = sampled_token_ids.transpose(0, 1) - return sampled_token_ids, sampled_token_probs, sampled_token_logprobs + if sampler_output_list[0].hidden_states is not None: + # shape: [batch_size, num_sampler_output, hidden_dim] + sampled_hidden_states = torch.stack( + [ + sampler_output.hidden_states + for sampler_output in sampler_output_list + ], + dim=0, + ) + + if sampler_transposed: + sampled_hidden_states = sampled_hidden_states.transpose(0, 1) + else: + sampled_hidden_states = None + + return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs, + sampled_hidden_states) def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, From 88c20e6122e1b07b47a2cf1677d23b98ffd48892 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Wed, 14 Aug 2024 15:58:18 +0530 Subject: [PATCH 25/34] making HiddenStates a dataclass and renaming last_non_bonus_hidden_states --- vllm/sequence.py | 48 ++++++++++++++------------ vllm/spec_decode/spec_decode_worker.py | 4 +-- 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 75adbd292ae19..160635502e1bc 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from array import array from collections import defaultdict -from dataclasses import dataclass, field +from dataclasses import InitVar, dataclass, field from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple, Union, cast) @@ -1040,6 +1040,7 @@ def get_all_seq_ids_and_request_ids( return seq_ids, request_id_seq_ids_mapping +@dataclass class HiddenStates: """Hidden states corresponding to in-progress sequences. Used in speculative decoding to pass hidden states from @@ -1047,35 +1048,38 @@ class HiddenStates: seq_ids are the sequence ids of each entry of the batch dimension of the hidden_states tensor""" - - def __init__( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - hidden_states: torch.Tensor, - last_non_bonus_token_hidden_states: Optional[torch.Tensor] = None): - assert len(seq_group_metadata_list) == len(hidden_states) + # The sequence group metadata list. + seq_group_metadata_list: InitVar[List[SequenceGroupMetadata]] + # Scorer hidden states of the last token accepted by the scorer. + hidden_states: torch.Tensor + # Scorer hidden states of the 2nd last token proposed by the proposer ( + # irrespective of whether it was accepted or not). Only used for cases when + # last proposed token is accepted (i.e., in case of bonus tokens). For the + # case of no bonus tokens, these are ignored. + second_last_token_hidden_states: Optional[torch.Tensor] = None + + def __post_init__(self, + seq_group_metadata_list: List[SequenceGroupMetadata]): + assert len(seq_group_metadata_list) == len(self.hidden_states) self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list) - self.hidden_states: torch.Tensor = hidden_states - self.last_non_bonus_token_hidden_states: Optional[ - torch.Tensor] = last_non_bonus_token_hidden_states def update( self, seq_group_metadata_list: List[SequenceGroupMetadata], hidden_states: torch.Tensor, - last_non_bonus_token_hidden_states: Optional[torch.Tensor] = None): + second_last_token_hidden_states: Optional[torch.Tensor] = None): """Update hidden states from target model invocation.""" assert len(seq_group_metadata_list) == len(hidden_states) self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) self.hidden_states = torch.cat([self.hidden_states, hidden_states]) - if self.last_non_bonus_token_hidden_states is not None: + if self.second_last_token_hidden_states is not None: # Adding dummy hidden_states to this to maintain same shape - self.last_non_bonus_token_hidden_states = torch.cat([ - self.last_non_bonus_token_hidden_states, + self.second_last_token_hidden_states = torch.cat([ + self.second_last_token_hidden_states, torch.zeros_like(hidden_states) - if last_non_bonus_token_hidden_states is None else - last_non_bonus_token_hidden_states + if second_last_token_hidden_states is None else + second_last_token_hidden_states ]) def prune(self, @@ -1086,14 +1090,14 @@ def prune(self, # Batch contents changed - prune removed sequences. index = [self.seq_ids.index(seq_id) for seq_id in seq_ids] self.hidden_states = self.hidden_states[index] - if self.last_non_bonus_token_hidden_states is not None: - self.last_non_bonus_token_hidden_states = self\ - .last_non_bonus_token_hidden_states[index] + if self.second_last_token_hidden_states is not None: + self.second_last_token_hidden_states = self\ + .second_last_token_hidden_states[index] self.seq_ids = seq_ids def expand_with_bonus_tokens( self, seq_with_bonus_token_in_last_step: set) -> None: - if self.last_non_bonus_token_hidden_states is None \ + if self.second_last_token_hidden_states is None \ or not seq_with_bonus_token_in_last_step: return @@ -1106,7 +1110,7 @@ def expand_with_bonus_tokens( self.hidden_states = torch.cat( [self.hidden_states, - self.last_non_bonus_token_hidden_states])[index] + self.second_last_token_hidden_states])[index] @dataclass diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 1ae3a056de2c9..2e2e4685581b1 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -678,12 +678,12 @@ def _verify_tokens( accepted_index = accepted_token_ids + 1 # Convert -1 to 0 accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) index = accepted_index[:, None, None].expand(-1, 1, hs_size) - last_non_bonus_token_hidden_states = hidden_states[:, -2] # b x d + second_last_token_hidden_states = hidden_states[:, -2] # b x d hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d # Store hidden states from target model for subsequent decode step self.previous_hidden_states = HiddenStates( seq_group_metadata_list, hidden_states, - last_non_bonus_token_hidden_states) + second_last_token_hidden_states) return accepted_token_ids, logprobs From 1753d9ad5ef07349ffbe7ce84ce44b29679cb92d Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Wed, 14 Aug 2024 17:36:28 +0530 Subject: [PATCH 26/34] reformatting --- vllm/model_executor/models/eagle.py | 2 +- vllm/model_executor/models/medusa.py | 2 +- vllm/sequence.py | 14 ++++++-------- vllm/worker/model_runner.py | 2 +- vllm/worker/worker_base.py | 11 ++++++----- 5 files changed, 15 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index 0c882bf6b643c..a5f78275e7ee5 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -62,7 +62,7 @@ def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None: # the draft model. Using smaller vocab size for draft, containing # only most frequent tokens reduces the speculation overhead. This # doesn't affect the acceptance rate much and thus gives more speed - # -up. By default, this is disabled and is only used if the EAGLE + # -up. By default, this is disabled and is only used if the EAGLE # checkpoint file has token_map tensor. self.token_map = None diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index 5d4aa39a26ed2..6fdb091140d2f 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -71,7 +71,7 @@ def __init__(self, config: MedusaConfig, **_) -> None: # the draft model. Using smaller vocab size for draft, containing # only most frequent tokens reduces the speculation overhead. This # doesn't affect the acceptance rate much and thus gives more speed - # -up. By default, this is disabled and is only used if the EAGLE + # -up. By default, this is disabled and is only used if the EAGLE # checkpoint file has token_map tensor. self.token_map = None diff --git a/vllm/sequence.py b/vllm/sequence.py index 7ca9a7c5e4324..017c33c544269 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1078,16 +1078,15 @@ class HiddenStates: # case of no bonus tokens, these are ignored. second_last_token_hidden_states: Optional[torch.Tensor] = None - def __post_init__(self, + def __post_init__(self, seq_group_metadata_list: List[SequenceGroupMetadata]): assert len(seq_group_metadata_list) == len(self.hidden_states) self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list) - def update( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - hidden_states: torch.Tensor, - second_last_token_hidden_states: Optional[torch.Tensor] = None): + def update(self, + seq_group_metadata_list: List[SequenceGroupMetadata], + hidden_states: torch.Tensor, + second_last_token_hidden_states: Optional[torch.Tensor] = None): """Update hidden states from target model invocation.""" assert len(seq_group_metadata_list) == len(hidden_states) self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) @@ -1129,8 +1128,7 @@ def expand_with_bonus_tokens( index.append(i) self.hidden_states = torch.cat( - [self.hidden_states, - self.second_last_token_hidden_states])[index] + [self.hidden_states, self.second_last_token_hidden_states])[index] @dataclass diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e36ea1b08fed3..4fe3324a1af92 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1210,7 +1210,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: slot_mapping.fill_(_PAD_SLOT_ID) seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() - + # Prepare dummy previous_hidden_states only if needed by the model. # This is used by draft models such as EAGLE. previous_hidden_states = None diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index e7a668525b092..3ea4038caaa81 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -220,8 +220,9 @@ def execute_worker(self, worker_input: WorkerInput) -> None: raise NotImplementedError def _get_worker_input_from_broadcast( - self) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput, - Dict[str, torch.Tensor]]]: + self + ) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput, Dict[ + str, torch.Tensor]]]: """ Get the worker input from the broadcasted tensor dict. """ assert self.do_metadata_broadcast assert not self.is_driver_worker @@ -233,7 +234,7 @@ def _get_worker_input_from_broadcast( model_input = ( self.model_runner.make_model_input_from_broadcasted_tensor_dict( broadcast_data)) - + kwargs = extract_previous_hidden_states(broadcast_data) return model_input, worker_input, kwargs @@ -265,8 +266,8 @@ def _get_driver_input_and_broadcast( def prepare_input( self, execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput, - Dict[str, torch.Tensor]]]: + ) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput, Dict[ + str, torch.Tensor]]]: """ Prepare the inputs to ModelRunner and workers. """ From 99484ae79538b1ebafc963691cec132a1584117a Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Fri, 16 Aug 2024 14:52:21 +0530 Subject: [PATCH 27/34] adding acceptance rate test for large output length --- tests/spec_decode/e2e/conftest.py | 23 ++++++---- tests/spec_decode/e2e/test_mlp_correctness.py | 42 +++++++++++++++++++ 2 files changed, 56 insertions(+), 9 deletions(-) diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index d0f91a63b2d6a..b9e68d87705a9 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -288,15 +288,17 @@ def run_greedy_equality_correctness_test(baseline_llm_generator, ensure_all_accepted=ensure_all_accepted) -def run_equality_correctness_test(baseline_llm_generator, - test_llm_generator, - batch_size, - max_output_len, - force_output_len: bool, - temperature: float, - seeded: bool, - print_tokens: bool = False, - ensure_all_accepted: bool = False): +def run_equality_correctness_test( + baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len, + force_output_len: bool, + temperature: float, + seeded: bool, + print_tokens: bool = False, + ensure_all_accepted: bool = False, + expected_acceptance_rate: Optional[float] = None): """Helper method that compares the outputs of both the baseline LLM and the test LLM. It asserts greedy equality, e.g. that the outputs are exactly the same when temperature is zero (or when temperature is > 0 and seeded). @@ -359,3 +361,6 @@ def run_equality_correctness_test(baseline_llm_generator, if ensure_all_accepted: assert acceptance_rate == 1.0 + + if expected_acceptance_rate is not None: + assert acceptance_rate >= expected_acceptance_rate - 1e-2 diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 25067e7a4262c..93b03acf050e8 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -82,6 +82,48 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, force_output_len=True) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + }, +]) +@pytest.mark.parametrize("output_len", [2048]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_mlp_e2e_acceptance_rate(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify acceptance rate with different batch size and large output + length.""" + run_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + temperature=0.0, + seeded=True, + force_output_len=True, + expected_acceptance_rate=0.6) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ From 2e51385139de4cc2514dad1aa4133e529c83e8b9 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Fri, 16 Aug 2024 15:07:15 +0530 Subject: [PATCH 28/34] fixing hidden states manipulation for batch expansion --- vllm/spec_decode/batch_expansion.py | 86 ++++++++++++++++++-------- vllm/spec_decode/spec_decode_worker.py | 5 +- vllm/spec_decode/top1_proposer.py | 2 +- vllm/spec_decode/util.py | 20 +++++- 4 files changed, 81 insertions(+), 32 deletions(-) diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 45eaeb51c5c0f..aa973391a3d93 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -1,5 +1,5 @@ from itertools import chain, count -from typing import Iterator, List, Tuple +from typing import Iterator, List, Optional, Tuple import torch @@ -86,21 +86,22 @@ def score_proposals( assert len(target_sampler_output) == 1, "expected single-step output" target_sampler_output = target_sampler_output[0] - all_tokens, all_probs, spec_logprobs = self._contract_batch( - contracted_bs=len(execute_model_req.seq_group_metadata_list), - target_sampler_output=target_sampler_output, - proposals=proposals, - num_scoring_tokens=num_scoring_tokens, - non_spec_indices=non_spec_indices, - spec_indices=spec_indices, - k=execute_model_req.num_lookahead_slots, - ) + (all_tokens, all_probs, spec_logprobs, + all_hidden_states) = self._contract_batch( + contracted_bs=len(execute_model_req.seq_group_metadata_list), + target_sampler_output=target_sampler_output, + proposals=proposals, + num_scoring_tokens=num_scoring_tokens, + non_spec_indices=non_spec_indices, + spec_indices=spec_indices, + k=execute_model_req.num_lookahead_slots, + ) return SpeculativeScores( probs=all_probs, token_ids=all_tokens, logprobs=spec_logprobs, - hidden_states=target_sampler_output.hidden_states, + hidden_states=all_hidden_states, ) def _expand_batch( @@ -143,10 +144,11 @@ def _expand_batch( num_scoring_tokens) def _contract_batch( - self, contracted_bs: int, target_sampler_output: SamplerOutput, - proposals: SpeculativeProposals, num_scoring_tokens: int, - non_spec_indices: List[int], spec_indices: List[int], - k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + self, contracted_bs: int, target_sampler_output: SamplerOutput, + proposals: SpeculativeProposals, num_scoring_tokens: int, + non_spec_indices: List[int], spec_indices: List[int], k: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor]]: """Contract the expanded batch back into its original size. This maps the scores of speculative tokens back to their original sequences. @@ -154,9 +156,10 @@ def _contract_batch( contracted_bs is the original batch size, and the batch size that the target_sampler_output will be contracted to. """ - (target_token_ids, target_probs, target_logprobs, + (target_token_ids, target_probs, target_logprobs, target_hidden_states, non_spec_target_token_ids, non_spec_target_probs, - non_spec_target_logprobs) = self._split_scoring_output( + non_spec_target_logprobs, + non_spec_target_hidden_states) = self._split_scoring_output( target_sampler_output, num_scoring_tokens) # Map distinct sequences used to score each token @@ -174,23 +177,40 @@ def _contract_batch( self._vocab_size) target_logprobs = target_logprobs.reshape(target_probs.shape) + if target_hidden_states is not None: + target_hidden_states = target_hidden_states.reshape( + spec_expanded_bs, k + 1, target_hidden_states.shape[-1]) + all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1), fill_value=-1) all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size) all_logprobs = target_logprobs.new_full(size=all_probs.shape, fill_value=-float("inf")) + if target_sampler_output.hidden_states is not None: + all_hidden_states = target_hidden_states.new_zeros( + size=(contracted_bs, k + 1, target_hidden_states.shape[-1])) + else: + all_hidden_states = None + if non_spec_indices: all_tokens[non_spec_indices, :1] = non_spec_target_token_ids all_probs[non_spec_indices, :1, :] = non_spec_target_probs all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs + if all_hidden_states is not None: + all_hidden_states[ + non_spec_indices, :1, :] = non_spec_target_hidden_states + if spec_indices: all_tokens[spec_indices] = target_token_ids all_probs[spec_indices] = target_probs all_logprobs[spec_indices] = target_logprobs - return all_tokens, all_probs, all_logprobs + if all_hidden_states is not None: + all_hidden_states[spec_indices] = target_hidden_states + + return all_tokens, all_probs, all_logprobs, all_hidden_states def _create_scoring_model_input( self, @@ -324,8 +344,9 @@ def _create_single_target_seq_group_metadata( def _split_scoring_output( self, sampler_output: SamplerOutput, num_scoring_tokens: int - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], torch.Tensor, torch.Tensor, + torch.Tensor, Optional[torch.Tensor]]: """Split the target model output into speculative and non-speculative output. """ @@ -350,24 +371,37 @@ def _split_scoring_output( non_spec_logprobs, ) = sampler_output.logprobs.split(split_sizes) + if sampler_output.hidden_states is not None: + ( + spec_hidden_states, + non_spec_hidden_states, + ) = sampler_output.hidden_states.split(split_sizes) + else: + spec_hidden_states, non_spec_hidden_states = None, None + # Convert scores to tensors. sampler_output.sampled_token_probs = spec_probs sampler_output.sampled_token_ids = spec_sampled_tokens sampler_output.logprobs = spec_logprobs - (target_token_ids, target_probs, - target_logprobs) = sampler_output_to_torch([sampler_output], True) + sampler_output.hidden_states = spec_hidden_states + (target_token_ids, target_probs, target_logprobs, + target_hidden_states) = sampler_output_to_torch([sampler_output], + True) # Convert non-speculative output tokens to tensors. sampler_output.sampled_token_probs = non_spec_probs sampler_output.sampled_token_ids = non_spec_sampled_tokens sampler_output.logprobs = non_spec_logprobs + sampler_output.hidden_states = non_spec_hidden_states (non_spec_target_token_ids, non_spec_target_probs, - non_spec_target_logprobs) = sampler_output_to_torch([sampler_output], - True) + non_spec_target_logprobs, + non_spec_target_hidden_states) = sampler_output_to_torch( + [sampler_output], True) return (target_token_ids, target_probs, target_logprobs, - non_spec_target_token_ids, non_spec_target_probs, - non_spec_target_logprobs) + target_hidden_states, non_spec_target_token_ids, + non_spec_target_probs, non_spec_target_logprobs, + non_spec_target_hidden_states) def _create_target_seq_id_iterator( self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 63a00139cc09d..acf77a7349eef 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -646,9 +646,8 @@ def _verify_tokens( hidden_states = proposal_scores.hidden_states if hidden_states is not None: # Contract hidden states based on accepted tokens - hs_size = hidden_states.shape[1] - hidden_states = hidden_states.reshape(-1, max_proposal_len + 1, - hs_size) + hs_size = hidden_states.shape[-1] + accepted_index = accepted_token_ids + 1 # Convert -1 to 0 accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) index = accepted_index[:, None, None].expand(-1, 1, hs_size) diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index 1a56497030280..28f7f7eb069ab 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -242,7 +242,7 @@ def _merge_outputs( return proposal_tokens, proposal_probs, proposal_lens_tensor sampler_output = maybe_sampler_output - proposal_tokens, proposal_probs, _ = sampler_output_to_torch( + proposal_tokens, proposal_probs, *_ = sampler_output_to_torch( sampler_output, sampler_transposed) # Now, reformat the output GPU tensors such that each sequence has diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index c6223a97dba10..b85f2a6f70ac0 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -123,7 +123,7 @@ def split_batch_by_proposal_len( def sampler_output_to_torch( sampler_output_list: List[SamplerOutput], sampler_transposed: bool -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """Utility function which converts a list of SamplerOutput to tensors. sampler_transposed here is used as the indicator for whether @@ -169,7 +169,23 @@ def sampler_output_to_torch( if sampler_transposed: sampled_token_ids = sampled_token_ids.transpose(0, 1) - return sampled_token_ids, sampled_token_probs, sampled_token_logprobs + if sampler_output_list[0].hidden_states is not None: + # shape: [batch_size, num_sampler_output, hidden_dim] + sampled_hidden_states = torch.stack( + [ + sampler_output.hidden_states + for sampler_output in sampler_output_list + ], + dim=0, + ) + + if sampler_transposed: + sampled_hidden_states = sampled_hidden_states.transpose(0, 1) + else: + sampled_hidden_states = None + + return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs, + sampled_hidden_states) def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, From d8bcff0b63ce540163de6add2c40e79ccc0eb89f Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Fri, 16 Aug 2024 16:32:17 +0530 Subject: [PATCH 29/34] print acceptance rate in spec decode tests --- tests/spec_decode/e2e/conftest.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index b9e68d87705a9..a701f482b4ffb 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -359,6 +359,8 @@ def run_equality_correctness_test( print(f'{i=} {spec_token_ids=}') assert baseline_token_ids == spec_token_ids + print(f'{acceptance_rate=}') + if ensure_all_accepted: assert acceptance_rate == 1.0 From 1654d4d7fac29151e5bc2e5913a7cf09d21b63d4 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Fri, 16 Aug 2024 17:00:53 +0530 Subject: [PATCH 30/34] Updating HiddenStates to handle prefill step as well --- tests/spec_decode/test_multi_step_worker.py | 2 +- vllm/sequence.py | 30 +++++---- vllm/spec_decode/spec_decode_worker.py | 68 +++++++++++++-------- vllm/worker/worker_base.py | 15 +++-- 4 files changed, 69 insertions(+), 46 deletions(-) diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index 96ea5eaa6861c..ada6c37d9af8d 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -709,7 +709,7 @@ def test_expand_execute_model_request_sync_with_expand_hidden_states(): execute_model_request = ExecuteModelRequest( seq_group_metadata_list, previous_hidden_states=HiddenStates( - seq_group_metadata_list, torch.arange(batch_size), + torch.arange(batch_size), seq_group_metadata_list, torch.arange(batch_size, 2 * batch_size))) expanded_execute_model_request, orig_seq_group_ids = MultiStepWorker.\ diff --git a/vllm/sequence.py b/vllm/sequence.py index 16eef950a8f33..6a21d39a86ed2 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1091,14 +1091,16 @@ def get_all_seq_ids_and_request_ids( class HiddenStates: """Hidden states corresponding to in-progress sequences. Used in speculative decoding to pass hidden states from - the target model to the proposer model in the subsequent step. + the target model to the proposer model. seq_ids are the sequence ids of each entry of the batch dimension of the hidden_states tensor""" - # The sequence group metadata list. - seq_group_metadata_list: InitVar[List[SequenceGroupMetadata]] - # Scorer hidden states of the last token accepted by the scorer. + # Scorer hidden states. For prefill step, it is used for hidden states of + # all tokens, whereas for decode step, it use used for last accepted tokens. hidden_states: torch.Tensor + # The sequence group metadata list. Only needed for decode step. + seq_group_metadata_list: InitVar[Optional[ + List[SequenceGroupMetadata]]] = None # Scorer hidden states of the 2nd last token proposed by the proposer ( # irrespective of whether it was accepted or not). Only used for cases when # last proposed token is accepted (i.e., in case of bonus tokens). For the @@ -1106,15 +1108,19 @@ class HiddenStates: second_last_token_hidden_states: Optional[torch.Tensor] = None def __post_init__(self, - seq_group_metadata_list: List[SequenceGroupMetadata]): - assert len(seq_group_metadata_list) == len(self.hidden_states) - self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list) + seq_group_metadata_list: Optional[ + List[SequenceGroupMetadata]] = None): + + if seq_group_metadata_list is not None: + assert len(seq_group_metadata_list) == len(self.hidden_states) + self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list) def update(self, - seq_group_metadata_list: List[SequenceGroupMetadata], hidden_states: torch.Tensor, + seq_group_metadata_list: List[SequenceGroupMetadata], second_last_token_hidden_states: Optional[torch.Tensor] = None): - """Update hidden states from target model invocation.""" + """Update hidden states from target model invocation. Only used for + decode steps""" assert len(seq_group_metadata_list) == len(hidden_states) self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) self.hidden_states = torch.cat([self.hidden_states, hidden_states]) @@ -1130,7 +1136,7 @@ def update(self, def prune(self, seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: - """Prune to provided list of sequence ids.""" + """Prune to provided list of sequence ids. Only used for decode steps""" seq_ids = get_all_seq_ids(seq_group_metadata_list) if seq_ids != self.seq_ids: # Batch contents changed - prune removed sequences. @@ -1143,6 +1149,8 @@ def prune(self, def expand_with_bonus_tokens( self, seq_with_bonus_token_in_last_step: set) -> None: + """Expand hidden states for sequences with bonus tokens. This is in + alignment with `MultiStepWorker._expand_execute_model_request`.""" if self.second_last_token_hidden_states is None \ or not seq_with_bonus_token_in_last_step: return @@ -1177,7 +1185,7 @@ class ExecuteModelRequest: # The number of requests in the running queue. running_queue_size: int = 0 # Optional hidden states from prior step. - previous_hidden_states: Union[HiddenStates, torch.Tensor, None] = None + previous_hidden_states: Optional[HiddenStates] = None # The number of forward steps to run. num_steps: int = 1 # Finished request ids since last step. diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index f1b02585a6aba..2762b8388029f 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -360,17 +360,34 @@ def execute_model( execute_model_req) num_lookahead_slots = execute_model_req.num_lookahead_slots + # Speculative decoding is disabled in the following cases: + # 1. Prefill phase: Speculative decoding is not + # used during the prefill phase. + # 2. Auto-disable enabled: The running queue size exceeds + # the specified threshold. + # 3. No request: There are no requests in the batch. + # In any of these cases, the proposer and scorer workers + # are called normally. + no_spec = num_lookahead_slots == 0 or len( + execute_model_req.seq_group_metadata_list + ) == 0 or disable_all_speculation + # Broadcast how many lookahead slots are scheduled for this step, and # whether all speculation is disabled, to all non-driver workers. # This is required as if the number of draft model runs changes # dynamically, the non-driver workers won't know unless we perform a # communication to inform them. + + # no_spec is used to signal non-driver worker about prefill vs decode + # stage. This is needed to ensure that order of execution of proposer + # and scorer is same in both driver and non-driver workers (i.e., + # scorer -> proposer for prefill and proposer -> scorer in decode). This + # order is needed to support models like EAGLE that take scorer states + # as inputs. broadcast_dict = dict( num_lookahead_slots=num_lookahead_slots, - no_spec=num_lookahead_slots == 0 - or len(execute_model_req.seq_group_metadata_list) == 0 - or disable_all_speculation, + no_spec=no_spec, disable_all_speculation=disable_all_speculation, ) broadcast_tensor_dict(broadcast_dict, src=self._driver_rank) @@ -381,17 +398,7 @@ def execute_model( self._maybe_disable_speculative_tokens( disable_all_speculation, execute_model_req.seq_group_metadata_list) - # Speculative decoding is disabled in the following cases: - # 1. Prefill phase: Speculative decoding is not - # used during the prefill phase. - # 2. Auto-disable enabled: The running queue size exceeds - # the specified threshold. - # 3. No request: There are no requests in the batch. - # In any of these cases, the proposer and scorer workers - # are called normally. - if num_lookahead_slots == 0 or len( - execute_model_req.seq_group_metadata_list - ) == 0 or disable_all_speculation: + if no_spec: return self._run_no_spec(execute_model_req, skip_proposer=disable_all_speculation) return self._run_speculative_decoding_step(execute_model_req, @@ -482,21 +489,18 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, if hidden_states is not None: if self.previous_hidden_states is None: self.previous_hidden_states = HiddenStates( - execute_model_req.seq_group_metadata_list, hidden_states) + hidden_states, execute_model_req.seq_group_metadata_list) else: self.previous_hidden_states.update( - execute_model_req.seq_group_metadata_list, hidden_states) + hidden_states, execute_model_req.seq_group_metadata_list) if not skip_proposer: - # For prefill step in proposer, we run the model for N-1 tokens - # because Nth token will be processed in the first decode step. For - # N-1 tokens, the input should be 0:N-1 hidden states which should - # be concatanated with 1:N token (since output of scorer has to be - # the input for proposer). Therefore, we shift the hidden states to - # align n-1th hidden state with nth token. - if sampler_output.prefill_hidden_states is not None: - execute_model_req.previous_hidden_states = sampler_output\ - .prefill_hidden_states.roll(shifts=1, dims=0) + # We prepare the prefill hidden states here so that there no + # additional complexity in worker for spec_decode vs non_spec_decode + # flow and execute_model doesn't need additional modifications. + execute_model_req.previous_hidden_states = \ + prepare_prefill_hidden_states( + sampler_output.prefill_hidden_states) self.proposer_worker.execute_model(execute_model_req) @@ -684,7 +688,7 @@ def _verify_tokens( hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d # Store hidden states from target model for subsequent decode step self.previous_hidden_states = HiddenStates( - seq_group_metadata_list, hidden_states, + hidden_states, seq_group_metadata_list, second_last_token_hidden_states) return accepted_token_ids, logprobs @@ -982,3 +986,15 @@ def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int, (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes)) return new_num_gpu_blocks + + +def prepare_prefill_hidden_states( + prefill_hidden_states: torch.Tensor) -> HiddenStates: + # For prefill step in proposer, we run the model for N-1 tokens + # because Nth token will be processed in the first decode step. For + # N-1 tokens, the input should be 0:N-1 hidden states which should + # be concatanated with 1:N token (since output of scorer has to be + # the input for proposer). Therefore, we shift the hidden states to + # align n-1th hidden state with nth token. + return HiddenStates(prefill_hidden_states.roll( + shifts=1, dims=0)) if prefill_hidden_states is not None else None diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index b75f076cfe76e..c9d0375321d14 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -12,8 +12,8 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.platforms import current_platform -from vllm.sequence import (ExecuteModelRequest, HiddenStates, - IntermediateTensors, SamplerOutput) +from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, + SamplerOutput) from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase @@ -466,14 +466,13 @@ def extract_previous_hidden_states( execute_model calls.""" output = {} + # When called from non-driver worker, data is dict but when called from + # driver worker, data is ExecuteModelRequest. if isinstance(data, dict): if "previous_hidden_states" in data: output["previous_hidden_states"] = data["previous_hidden_states"] - else: - if isinstance(data.previous_hidden_states, torch.Tensor): - output["previous_hidden_states"] = data.previous_hidden_states - elif isinstance(data.previous_hidden_states, HiddenStates): - output["previous_hidden_states"] = data.previous_hidden_states\ - .hidden_states + elif data.previous_hidden_states is not None: + output["previous_hidden_states"] = data.previous_hidden_states\ + .hidden_states return output From 6954eadde2d5a40560ed2e2287763e44b07782ee Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Sat, 17 Aug 2024 10:26:59 +0530 Subject: [PATCH 31/34] changing expected acceptance rate for test --- tests/spec_decode/e2e/test_mlp_correctness.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 93b03acf050e8..c72e4595fd335 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -121,7 +121,7 @@ def test_mlp_e2e_acceptance_rate(baseline_llm_generator, test_llm_generator, temperature=0.0, seeded=True, force_output_len=True, - expected_acceptance_rate=0.6) + expected_acceptance_rate=0.48) @pytest.mark.parametrize( From f906cefc43102fa05a92bf54a24da03aa7d78052 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Mon, 19 Aug 2024 10:51:38 +0530 Subject: [PATCH 32/34] formatting --- vllm/sequence.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 56a2b4141843f..f6c4a5a50ffc0 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1192,7 +1192,7 @@ class HiddenStates(msgspec.Struct, array_like=True, # last proposed token is accepted (i.e., in case of bonus tokens). For the # case of no bonus tokens, these are ignored. second_last_token_hidden_states: Optional[torch.Tensor] = None - + _seq_ids: List[int] = msgspec.field(default_factory=list) def __post_init__(self): @@ -1227,8 +1227,8 @@ def prune(self, seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: """Prune to provided list of sequence ids. Only used for decode steps. """ - # Currently this prunes all seq_ids not present in - # seq_group_metadata_list which might cause problems where a sequence + # Currently this prunes all seq_ids not present in + # seq_group_metadata_list which might cause problems where a sequence # may be "paused" then "resumed" later. This should only prune sequences # which are confirmed to be aborted. seq_ids = get_all_seq_ids(seq_group_metadata_list) From 3febb95f56ba35546ec492451bb8c6035f883e48 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Tue, 20 Aug 2024 12:57:29 +0530 Subject: [PATCH 33/34] Fixing compatibility of `worker.multi_step_worker.MultiStepWorker` with `worker.worker_base.LocalOrDistributedWorkerBase` --- vllm/worker/multi_step_worker.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py index 6a6caba9371eb..2ed77dd698f5c 100644 --- a/vllm/worker/multi_step_worker.py +++ b/vllm/worker/multi_step_worker.py @@ -1,5 +1,7 @@ from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple + +import torch from vllm.distributed import broadcast_tensor_dict, get_pp_group from vllm.sequence import ExecuteModelRequest, SamplerOutput @@ -43,7 +45,7 @@ def __init__(self, *args, **kwargs): def _get_driver_input_and_broadcast( self, execute_model_req: ExecuteModelRequest - ) -> Tuple[BroadcastableModelInput, WorkerInput]: + ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]: """ Get the driver input and broadcast it to other workers. """ @@ -85,7 +87,9 @@ def _get_driver_input_and_broadcast( broadcast_data.update(model_input.as_broadcastable_tensor_dict()) broadcast_tensor_dict(broadcast_data, src=0) - return model_input, worker_input + # Retuning empty dict here to keep this compatible with + # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast` + return model_input, worker_input, {} def _prepare_last_sampled_token_ids_for_tp_workers( self, @@ -130,7 +134,8 @@ def _prepare_last_sampled_token_ids_for_tp_workers( def prepare_input( self, execute_model_req: Optional[ExecuteModelRequest] = None, - ) -> Optional[Tuple[StatefulModelInput, WorkerInput]]: + ) -> Optional[Tuple[StatefulModelInput, WorkerInput, Dict[str, + torch.Tensor]]]: """ Depending on the current state of the request and multi step worker, this method may skip the normal _prepare_model_input and @@ -148,8 +153,8 @@ def prepare_input( return None virtual_engine = execute_model_req.virtual_engine - model_input, worker_input = self._get_driver_input_and_broadcast( - execute_model_req) + (model_input, worker_input, + kwargs) = self._get_driver_input_and_broadcast(execute_model_req) assert isinstance(model_input, StatefulModelInput) if execute_model_req.is_first_multi_step: # cache the worker input and model input for the next steps @@ -162,7 +167,7 @@ def prepare_input( # loop if broadcast_data is None: return None - model_input, worker_input = broadcast_data + model_input, worker_input, kwargs = broadcast_data assert isinstance(model_input, StatefulModelInput) virtual_engine = worker_input.virtual_engine if model_input.is_first_multi_step: @@ -186,4 +191,4 @@ def prepare_input( assert model_input is not None assert worker_input is not None - return model_input, worker_input + return model_input, worker_input, kwargs From 284468dfce180ccba8adb073bc4b8af07a3eb12f Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Thu, 22 Aug 2024 11:57:49 +0530 Subject: [PATCH 34/34] adding comment --- vllm/worker/worker_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 0a812fca865b7..516e386595195 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -465,7 +465,7 @@ def extract_previous_hidden_states( Dict[str, torch.Tensor]: """If data contains previous_hidden_states, extract it. This returns a dict which can be used directly as additional kwargs in any following - execute_model calls.""" + execute_model calls. This is used in draft models like EAGLE.""" output = {} # When called from non-driver worker, data is dict but when called from