From 9d12c23a3ec08ca97d7dec4d793f3fdf06179788 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Mon, 16 Dec 2024 02:27:07 +0000 Subject: [PATCH] spec decode: add support for EAGLE --- aphrodite/common/config.py | 4 +- aphrodite/common/sequence.py | 77 +++- aphrodite/modeling/models/__init__.py | 1 + aphrodite/modeling/models/eagle.py | 155 ++++++++ aphrodite/modeling/models/medusa.py | 19 + aphrodite/spec_decode/draft_model_runner.py | 18 + aphrodite/spec_decode/multi_step_worker.py | 11 +- aphrodite/spec_decode/spec_decode_worker.py | 91 +++-- aphrodite/task_handler/model_runner.py | 39 +- aphrodite/task_handler/multi_step_worker.py | 23 +- aphrodite/task_handler/worker.py | 2 +- aphrodite/task_handler/worker_base.py | 56 ++- aphrodite/transformers_utils/config.py | 2 + .../transformers_utils/configs/__init__.py | 2 + aphrodite/transformers_utils/configs/eagle.py | 58 +++ pytest.ini | 2 + tests/benchmarks/engine/throughput.py | 23 +- .../spec_decode/e2e/test_eagle_correctness.py | 301 +++++++++++++++ .../e2e/test_medusa_correctness.py | 62 ++- tests/spec_decode/test_multi_step_worker.py | 357 +++++++++++------- 20 files changed, 1085 insertions(+), 218 deletions(-) create mode 100644 aphrodite/modeling/models/eagle.py create mode 100644 aphrodite/transformers_utils/configs/eagle.py create mode 100644 tests/spec_decode/e2e/test_eagle_correctness.py diff --git a/aphrodite/common/config.py b/aphrodite/common/config.py index ce5e19afd..3257e8330 100644 --- a/aphrodite/common/config.py +++ b/aphrodite/common/config.py @@ -521,12 +521,14 @@ def get_hidden_size(self) -> int: def get_head_size(self) -> int: # TODO remove hard code + spec_model_types = ["medusa", "mlp_speculator"] if hasattr(self.hf_text_config, "model_type" ) and self.hf_text_config.model_type == 'deepseek_v2': # FlashAttention supports only head_size 32, 64, 128, 256, # we need to pad head_size 192 to 256 return 256 - if self.is_attention_free(): + if self.is_attention_free() or \ + self.hf_text_config.model_type in spec_model_types: return 0 if hasattr(self.hf_text_config, "head_dim"): return self.hf_text_config.head_dim diff --git a/aphrodite/common/sequence.py b/aphrodite/common/sequence.py index 95f5cecdd..86f4c289b 100644 --- a/aphrodite/common/sequence.py +++ b/aphrodite/common/sequence.py @@ -1116,6 +1116,13 @@ class SamplerOutput( # Optional last hidden states from the model. hidden_states: Optional[torch.Tensor] = None + # Optional prefill hidden states from the model + # (used for models like EAGLE). + prefill_hidden_states: Optional[torch.Tensor] = None + + # Time taken in the forward pass for this across all workers + model_forward_time: Optional[float] = None + def __getitem__(self, idx: int): return self.outputs[idx] @@ -1189,47 +1196,87 @@ def get_all_seq_ids_and_request_ids( return seq_ids, request_id_seq_ids_mapping -class HiddenStates( - msgspec.Struct, - omit_defaults=True, - array_like=True -): +class HiddenStates(msgspec.Struct, array_like=True, + omit_defaults=True): # type: ignore[call-arg] """Hidden states corresponding to in-progress sequences. Used in speculative decoding to pass hidden states from - the target model to the proposer model in the subsequent step. - + the target model to the proposer model. seq_ids are the sequence ids of each entry of the batch dimension of the hidden_states tensor""" - - seq_group_metadata_list: List[SequenceGroupMetadata] + # Scorer hidden states. 
For prefill step, it is used for hidden states of
+    # all tokens, whereas for decode step, it is used for last accepted tokens.
     hidden_states: torch.Tensor
+    # The sequence group metadata list. Only needed for decode step.
+    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
+    # Scorer hidden states of the 2nd last token proposed by the proposer (
+    # irrespective of whether it was accepted or not). Only used for cases when
+    # last proposed token is accepted (i.e., in case of bonus tokens). For the
+    # case of no bonus tokens, these are ignored.
+    second_last_token_hidden_states: Optional[torch.Tensor] = None
     _seq_ids: List[int] = msgspec.field(default_factory=list)
 
     def __post_init__(self):
-        self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
-        assert len(self.seq_group_metadata_list) == len(self.hidden_states)
+        if self.seq_group_metadata_list is not None:
+            assert len(self.seq_group_metadata_list) == len(self.hidden_states)
+            self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
 
     @property
     def seq_ids(self) -> List[int]:
         return self._seq_ids
 
-    def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
-               hidden_states: torch.Tensor) -> None:
-        """Update hidden states from target model invocation."""
+    def update(self,
+               hidden_states: torch.Tensor,
+               seq_group_metadata_list: List[SequenceGroupMetadata],
+               second_last_token_hidden_states: Optional[torch.Tensor] = None):
+        """Update hidden states from target model invocation. Only used for
+        decode steps."""
         assert len(seq_group_metadata_list) == len(hidden_states)
         self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
         self.hidden_states = torch.cat([self.hidden_states, hidden_states])
 
+        if self.second_last_token_hidden_states is not None:
+            # Adding dummy hidden_states to this to maintain same shape
+            self.second_last_token_hidden_states = torch.cat([
+                self.second_last_token_hidden_states,
+                torch.zeros_like(hidden_states)
+                if second_last_token_hidden_states is None else
+                second_last_token_hidden_states
+            ])
+
     def prune(self,
               seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
-        """Prune to provided list of sequence ids."""
+        """Prune to provided list of sequence ids. Only used for decode steps.
+        """
+        # Currently this prunes all seq_ids not present in
+        # seq_group_metadata_list which might cause problems where a sequence
+        # may be "paused" then "resumed" later. This should only prune sequences
+        # which are confirmed to be aborted.
         seq_ids = get_all_seq_ids(seq_group_metadata_list)
         if seq_ids != self._seq_ids:
             # Batch contents changed - prune removed sequences.
             index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
             self.hidden_states = self.hidden_states[index]
+            if self.second_last_token_hidden_states is not None:
+                self.second_last_token_hidden_states = self\
+                    .second_last_token_hidden_states[index]
             self._seq_ids = seq_ids
 
+    def expand_with_bonus_tokens(
+            self, seq_with_bonus_token_in_last_step: set) -> None:
+        """Expand hidden states for sequences with bonus tokens.
This is in + alignment with `MultiStepWorker._expand_execute_model_request`.""" + if self.second_last_token_hidden_states is None \ + or not seq_with_bonus_token_in_last_step: + return + index = [] + for seq_id in self._seq_ids: + i = self._seq_ids.index(seq_id) + if seq_id in seq_with_bonus_token_in_last_step: + index.append(i + len(self._seq_ids)) + index.append(i) + self.hidden_states = torch.cat( + [self.hidden_states, self.second_last_token_hidden_states])[index] + class ExecuteModelRequest( msgspec.Struct, diff --git a/aphrodite/modeling/models/__init__.py b/aphrodite/modeling/models/__init__.py index 711ddb6db..6673a6527 100755 --- a/aphrodite/modeling/models/__init__.py +++ b/aphrodite/modeling/models/__init__.py @@ -61,6 +61,7 @@ "JambaForCausalLM": ("jamba", "JambaForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"), "MedusaModel": ("medusa", "Medusa"), + "EAGLEModel": ("eagle", "EAGLE"), "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "SolarForCausalLM": ("solar", "SolarForCausalLM"), "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), diff --git a/aphrodite/modeling/models/eagle.py b/aphrodite/modeling/models/eagle.py new file mode 100644 index 000000000..06331c2e6 --- /dev/null +++ b/aphrodite/modeling/models/eagle.py @@ -0,0 +1,155 @@ +from typing import Iterable, List, Optional, Tuple + +import torch +import torch.nn as nn + +from aphrodite.attention.backends.abstract import AttentionMetadata +from aphrodite.common.sequence import IntermediateTensors, SamplerOutput +from aphrodite.modeling.layers.logits_processor import LogitsProcessor +from aphrodite.modeling.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) +from aphrodite.modeling.model_loader.weight_utils import default_weight_loader +from aphrodite.modeling.models import ModelRegistry +from aphrodite.modeling.sampling_metadata import SamplingMetadata +from aphrodite.transformers_utils.configs.eagle import EAGLEConfig + + +class EAGLE(nn.Module): + """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077 + Reference implementation: https://github.com/SafeAILab/EAGLE + + Differences from reference implementation: + 1. In reference, LlamaDecoderLayer implementation doesn't have + input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427) + but we do as HF implementation also does. + 2. We allow any decoder layer to be used in EAGLE whereas in reference + decoder layer is fixed to be LlamaDecoderLayer. + 3. We have an optional token_map which reduces draft vocab to most + frequently used tokens to give some additional speed-up by reducing + sampling overhead. This is disabled unless the checkpoint file has + explicit token_map tensor and config has an optional attribute + truncated_vocab_size < vocab_size. To use this technique, one has to find + the top-k most frequent tokens in target dataset and add that as a tensor + in the draft checkpoint (using key token_map). 
Also, the draft config + needs to have truncated_vocab_size (=k) as an attribute.""" + + def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None: + super().__init__() + self.config = config + architectures = getattr(self.config.model, "architectures", []) + model_cls, _ = ModelRegistry.resolve_model_cls(architectures) + self.model = model_cls(self.config.model, *args, **kwargs) + self.fc = nn.Linear( + config.model.hidden_size * 2, config.model.hidden_size, bias=False + ) + self.orig_vocab_size = config.vocab_size + self.truncated_vocab_size = config.truncated_vocab_size + self.unpadded_vocab_size = self.truncated_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=self.truncated_vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, self.truncated_vocab_size, logit_scale + ) + # Token map is a idx to token mapping to reduce the vocab size for + # the draft model. Using smaller vocab size for draft, containing + # only most frequent tokens reduces the speculation overhead. This + # doesn't affect the acceptance rate much and thus gives more speed + # -up. By default, this is disabled and is only used if the EAGLE + # checkpoint file has token_map tensor. + self.token_map = None + + @property + def sampler(self): + return self.model.sampler + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + tok_embeds = self.model.model.embed_tokens(input_ids) + inputs_embeds = self.fc( + torch.cat([tok_embeds, previous_hidden_states], dim=-1) + ) + inputs_embeds[positions == 0] = 0 # masking inputs at position=0 + hidden_states = self.model.model( + input_ids=None, + inputs_embeds=inputs_embeds, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + ) + return hidden_states + + def compute_logits( + self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata + ) -> torch.Tensor: + logits = self.logits_processor( + self.lm_head, hidden_states, sampling_metadata + ) + if self.token_map is not None: + _logits = logits + logits = -torch.inf * torch.ones( + size=(*_logits.shape[:-1], self.orig_vocab_size), + device=_logits.device, + dtype=_logits.dtype, + ) + logits[..., self.token_map] = _logits + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # This implementation is incompatible with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B + # due to missing lm_head weights and its config being that of a + # Llama model. 
Here's a compatible version with the same weights: + # https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm + # Also, here's an example script for converting trained EAGLE + # checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d + model_weights = {} + for name, loaded_weight in weights: + if name == "token_map": + if self.config.truncated_vocab_size < self.config.vocab_size: + self.token_map = nn.Parameter( + loaded_weight, requires_grad=False + ) + elif name.startswith("fc."): + weight_loader = getattr( + self.fc.weight, "weight_loader", default_weight_loader + ) + weight_loader(self.fc.weight, loaded_weight) + elif name.startswith("model.lm_head.") or name.startswith( + "model.model." + ): + model_weights[name.split("model.", 1)[-1]] = loaded_weight + elif name.startswith("lm_head.") or name.startswith("model."): + model_weights[name] = loaded_weight + else: + model_weights[f"model.{name}"] = loaded_weight + lm_head_weight = model_weights.pop("lm_head.weight") + if ( + self.token_map is not None + and lm_head_weight.shape[0] > self.token_map.shape[0] + ): + lm_head_weight = lm_head_weight[self.token_map] + weight_loader = getattr( + self.lm_head.weight, "weight_loader", default_weight_loader + ) + weight_loader(self.lm_head.weight, lm_head_weight) + self.model.load_weights(model_weights.items()) diff --git a/aphrodite/modeling/models/medusa.py b/aphrodite/modeling/models/medusa.py index bd59de0f7..8bb59c51c 100644 --- a/aphrodite/modeling/models/medusa.py +++ b/aphrodite/modeling/models/medusa.py @@ -30,6 +30,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Medusa(nn.Module): + """This class implements the Medusa draft model from the paper: https://arxiv.org/abs/2401.10774 + Reference implementation: https://github.com/FasterDecoding/Medusa + + Differences from reference implementation: + 1. Currently this only supports generating proposals from top-1 tokens. + 2. We have an optional token_map which reduces draft vocab to most + frequently used tokens to give some additional speed-up by reducing + sampling overhead. This is disabled unless the checkpoint file has + explicit token_map tensor and config has an optional attribute + truncated_vocab_size < vocab_size. To use this technique, one has to find + the top-k most frequent tokens in target dataset and add that as a tensor + in the draft checkpoint (using key token_map). Also, the draft config + needs to have truncated_vocab_size (=k) as an attribute.""" def __init__(self, config: MedusaConfig, **_) -> None: super().__init__() @@ -57,6 +70,12 @@ def __init__(self, config: MedusaConfig, **_) -> None: self.truncated_vocab_size, logit_scale) + # Token map is a idx to token mapping to reduce the vocab size for + # the draft model. Using smaller vocab size for draft, containing + # only most frequent tokens reduces the speculation overhead. This + # doesn't affect the acceptance rate much and thus gives more speed + # -up. By default, this is disabled and is only used if the EAGLE + # checkpoint file has token_map tensor. 
self.token_map = None def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]: diff --git a/aphrodite/spec_decode/draft_model_runner.py b/aphrodite/spec_decode/draft_model_runner.py index b5d6b8a93..9eda8998b 100644 --- a/aphrodite/spec_decode/draft_model_runner.py +++ b/aphrodite/spec_decode/draft_model_runner.py @@ -201,6 +201,7 @@ def execute_model( self, model_input: ModelInputForGPUWithSamplingMetadata, kv_caches: List[torch.Tensor], + previous_hidden_states: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: @@ -278,13 +279,29 @@ def execute_model( graph_batch_size = model_input.input_tokens.shape[0] model_executable = (self.graph_runners[model_input.virtual_engine] [graph_batch_size]) + if previous_hidden_states is not None: + hidden_states = torch.cat([ + previous_hidden_states, + torch.empty([ + graph_batch_size - previous_hidden_states.shape[0], + *previous_hidden_states.shape[1:] + ], + dtype=previous_hidden_states.dtype, + device=previous_hidden_states.device) + ]) + else: + hidden_states = None else: model_executable = self.model + hidden_states = previous_hidden_states outputs: List[SamplerOutput] = [] for step in range(num_steps): multi_modal_kwargs = model_input.multi_modal_kwargs or {} + kwargs = {"previous_hidden_states": hidden_states} \ + if previous_hidden_states is not None else {} + # Run model hidden_states = model_executable( input_ids=model_input.input_tokens, @@ -294,6 +311,7 @@ def execute_model( intermediate_tensors=intermediate_tensors, **MultiModalInputs.as_kwargs(multi_modal_kwargs, device=self.device), + **kwargs, ) # Compute the logits. diff --git a/aphrodite/spec_decode/multi_step_worker.py b/aphrodite/spec_decode/multi_step_worker.py index 6cf166f48..7e7f09dc7 100644 --- a/aphrodite/spec_decode/multi_step_worker.py +++ b/aphrodite/spec_decode/multi_step_worker.py @@ -4,8 +4,9 @@ import torch -from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput, - SequenceData, SequenceGroupMetadata) +from aphrodite.common.sequence import (ExecuteModelRequest, HiddenStates, + SamplerOutput, SequenceData, + SequenceGroupMetadata) from aphrodite.spec_decode.draft_model_runner import TP1DraftModelRunner from aphrodite.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) @@ -157,6 +158,12 @@ def _expand_execute_model_request( updated_execute_model_req.seq_group_metadata_list =\ updated_seq_group_metadata_list + + if isinstance(updated_execute_model_req.previous_hidden_states, + HiddenStates): + updated_execute_model_req.previous_hidden_states\ + .expand_with_bonus_tokens(seq_with_bonus_token_in_last_step) + return updated_execute_model_req, indices_of_original_sequence_groups @staticmethod diff --git a/aphrodite/spec_decode/spec_decode_worker.py b/aphrodite/spec_decode/spec_decode_worker.py index 986a6126e..25940c9aa 100644 --- a/aphrodite/spec_decode/spec_decode_worker.py +++ b/aphrodite/spec_decode/spec_decode_worker.py @@ -146,6 +146,10 @@ def create_worker( draft_worker_kwargs[ "model_runner_cls"] = TP1DraftModelRunner else: + if draft_worker_kwargs[ + "model_config"].hf_config.model_type == "eagle": + raise NotImplementedError( + "EAGLE does not support TP > 1 yet") allow_zero_draft_token_step = False proposer_worker = MultiStepWorker(**draft_worker_kwargs) @@ -354,14 +358,33 @@ def execute_model( execute_model_req) num_lookahead_slots = execute_model_req.num_lookahead_slots + # Speculative decoding is disabled in the 
following cases: + # 1. Prefill phase: Speculative decoding is not + # used during the prefill phase. + # 2. Auto-disable enabled: The running queue size exceeds + # the specified threshold. + # 3. No request: There are no requests in the batch. + # In any of these cases, the proposer and scorer workers + # are called normally. + no_spec = num_lookahead_slots == 0 or len( + execute_model_req.seq_group_metadata_list + ) == 0 or disable_all_speculation + # Broadcast how many lookahead slots are scheduled for this step, and # whether all speculation is disabled, to all non-driver workers. # This is required as if the number of draft model runs changes # dynamically, the non-driver workers won't know unless we perform a # communication to inform them. + # no_spec is used to signal non-driver worker about prefill vs decode + # stage. This is needed to ensure that order of execution of proposer + # and scorer is same in both driver and non-driver workers (i.e., + # scorer -> proposer for prefill and proposer -> scorer in decode). This + # order is needed to support models like EAGLE that take scorer states + # as inputs. broadcast_dict = dict( num_lookahead_slots=num_lookahead_slots, + no_spec=no_spec, disable_all_speculation=disable_all_speculation, ) broadcast_tensor_dict(broadcast_dict, src=self._driver_rank) @@ -372,17 +395,7 @@ def execute_model( self._maybe_disable_speculative_tokens( disable_all_speculation, execute_model_req.seq_group_metadata_list) - # Speculative decoding is disabled in the following cases: - # 1. Prefill phase: Speculative decoding is not - # used during the prefill phase. - # 2. Auto-disable enabled: The running queue size exceeds - # the specified threshold. - # 3. No request: There are no requests in the batch. - # In any of these cases, the proposer and scorer workers - # are called normally. - if num_lookahead_slots == 0 or len( - execute_model_req.seq_group_metadata_list - ) == 0 or disable_all_speculation: + if no_spec: return self._run_no_spec(execute_model_req, skip_proposer=disable_all_speculation) return self._run_speculative_decoding_step(execute_model_req, @@ -463,8 +476,6 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, not called, meaning that the kv-cache in proposer for requests is not updated, so they cannot enable spec decode in the rest decoding. """ - if not skip_proposer: - self.proposer_worker.execute_model(execute_model_req) sampler_output = self.scorer_worker.execute_model(execute_model_req) assert len(sampler_output) == 1 @@ -475,10 +486,18 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, if hidden_states is not None: if self.previous_hidden_states is None: self.previous_hidden_states = HiddenStates( - execute_model_req.seq_group_metadata_list, hidden_states) + hidden_states, execute_model_req.seq_group_metadata_list) else: self.previous_hidden_states.update( - execute_model_req.seq_group_metadata_list, hidden_states) + hidden_states, execute_model_req.seq_group_metadata_list) + if not skip_proposer: + # We prepare the prefill hidden states here so that there no + # additional complexity in worker for spec_decode vs non_spec_decode + # flow and execute_model doesn't need additional modifications. 
+ execute_model_req.previous_hidden_states = \ + prepare_prefill_hidden_states( + sampler_output.prefill_hidden_states) + self.proposer_worker.execute_model(execute_model_req) sampler_output_to_return = (self._serialize_sampler_output_no_logprobs( execute_model_req=execute_model_req, sampler_output=sampler_output) @@ -506,15 +525,21 @@ def _run_non_driver_rank(self) -> bool: return False num_lookahead_slots = data["num_lookahead_slots"] - # Even if num_lookahead_slots is zero, we want to run the proposer model - # as it may have KV. - # - # We run the proposer once per lookahead slot. In the future we should - # delegate how many times it runs to the proposer. - for _ in range(max(num_lookahead_slots, 1)): - self.proposer_worker.execute_model() + # In case of prefill, scorer_worker has to be run before proposer so + # that the hidden states can be propagated to proposer when needed. + if data["no_spec"]: + self.scorer_worker.execute_model() + if not data["disable_all_speculation"]: + # Even if num_lookahead_slots is zero, we want to run the + # proposer model as it may have KV. + # + # We run the proposer once per lookahead slot. In the future we + # should delegate how many times it runs to the proposer. + for _ in range(max(num_lookahead_slots, 1)): + self.proposer_worker.execute_model() + if not data["no_spec"]: + self.scorer_worker.execute_model() - self.scorer_worker.execute_model() return True @nvtx_range("spec_decode_worker._run_speculative_decoding_step") @@ -545,6 +570,8 @@ def _run_speculative_decoding_step( raise RuntimeError("Cannot handle cases where distributed draft " "workers generate no tokens") + execute_model_req.previous_hidden_states = None + with Timer() as scoring_timer: proposal_scores = self.scorer.score_proposals( execute_model_req, @@ -649,10 +676,12 @@ def _verify_tokens( accepted_index = accepted_token_ids + 1 # Convert -1 to 0 accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) index = accepted_index[:, None, None].expand(-1, 1, hs_size) + second_last_token_hidden_states = hidden_states[:, -2] # b x d hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d # Store hidden states from target model for subsequent decode step - self.previous_hidden_states = HiddenStates(seq_group_metadata_list, - hidden_states) + self.previous_hidden_states = HiddenStates( + hidden_states, seq_group_metadata_list, + second_last_token_hidden_states) return accepted_token_ids, logprobs @@ -948,3 +977,15 @@ def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int, (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes)) return new_num_gpu_blocks + + +def prepare_prefill_hidden_states( + prefill_hidden_states: torch.Tensor) -> HiddenStates: + # For prefill step in proposer, we run the model for N-1 tokens + # because Nth token will be processed in the first decode step. For + # N-1 tokens, the input should be 0:N-1 hidden states which should + # be concatanated with 1:N token (since output of scorer has to be + # the input for proposer). Therefore, we shift the hidden states to + # align n-1th hidden state with nth token. 
+ return HiddenStates(prefill_hidden_states.roll( + shifts=1, dims=0)) if prefill_hidden_states is not None else None diff --git a/aphrodite/task_handler/model_runner.py b/aphrodite/task_handler/model_runner.py index 9c4f4dc15..11cd1ee54 100644 --- a/aphrodite/task_handler/model_runner.py +++ b/aphrodite/task_handler/model_runner.py @@ -1,5 +1,6 @@ import dataclasses import gc +import inspect import itertools import os import time @@ -1242,6 +1243,16 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() + # Prepare dummy previous_hidden_states only if needed by the model. + # This is used by draft models such as EAGLE. + previous_hidden_states = None + if "previous_hidden_states" in inspect.signature( + self.model.forward).parameters: + previous_hidden_states = torch.empty( + [max_batch_size, + self.model_config.get_hidden_size()], + dtype=self.model_config.dtype, + device=self.device) intermediate_inputs = None if not get_pp_group().is_first_rank: intermediate_inputs = self.model.make_empty_intermediate_tensors( @@ -1316,6 +1327,10 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: "stream": graph_capture_context.stream } + if previous_hidden_states is not None: + capture_inputs[ + "previous_hidden_states"] = previous_hidden_states[: + batch_size] if self.has_seqlen_agnostic: # Only used by Mamba-based models CUDA graph atm (Jamba) capture_inputs.update({ @@ -1475,6 +1490,7 @@ def execute_model( if model_input.is_prompt: hidden_states = hidden_or_intermediate_states.index_select( 0, indices) + output.prefill_hidden_states = hidden_or_intermediate_states elif decode_meta.use_cuda_graph: hidden_states = hidden_or_intermediate_states[:len(indices)] else: @@ -1523,11 +1539,11 @@ def capture( # Note one iteration is not enough for torch.jit.script for _ in range(_NUM_WARMUP_ITERS): self.model( - input_ids, - positions, - kv_caches, - attn_metadata, - intermediate_inputs, + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_inputs, **kwargs, ) torch.cuda.synchronize() @@ -1536,11 +1552,11 @@ def capture( self._graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): output_hidden_or_intermediate_states = self.model( - input_ids, - positions, - kv_caches, - attn_metadata, - intermediate_inputs, + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_inputs, **kwargs, ) if hidden_or_intermediate_states is not None: @@ -1602,6 +1618,9 @@ def forward( if "seqlen_agnostic_capture_inputs" in self.input_buffers: self.model.copy_inputs_before_cuda_graphs(self.input_buffers, **kwargs) + if "previous_hidden_states" in self.input_buffers: + self.input_buffers["previous_hidden_states"].copy_( + kwargs["previous_hidden_states"], non_blocking=True) if intermediate_tensors is not None: for key in intermediate_tensors.tensors: self.input_buffers[key].copy_(intermediate_tensors[key], diff --git a/aphrodite/task_handler/multi_step_worker.py b/aphrodite/task_handler/multi_step_worker.py index 241cf5fad..c6fa70255 100644 --- a/aphrodite/task_handler/multi_step_worker.py +++ b/aphrodite/task_handler/multi_step_worker.py @@ -1,5 +1,7 @@ from dataclasses import dataclass -from typing import 
List, Optional, Tuple +from typing import Dict, List, Optional, Tuple + +import torch from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput from aphrodite.distributed import broadcast_tensor_dict, get_pp_group @@ -41,7 +43,7 @@ def __init__(self, *args, **kwargs): def _get_driver_input_and_broadcast( self, execute_model_req: ExecuteModelRequest - ) -> Tuple[BroadcastableModelInput, WorkerInput]: + ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]: """ Get the driver input and broadcast it to other workers. """ @@ -83,7 +85,10 @@ def _get_driver_input_and_broadcast( broadcast_data = worker_input.as_broadcastable_tensor_dict() broadcast_data.update(model_input.as_broadcastable_tensor_dict()) broadcast_tensor_dict(broadcast_data, src=0) - return model_input, worker_input + + # Retuning empty dict here to keep this compatible with + # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast` + return model_input, worker_input, {} def _prepare_last_sampled_token_ids_for_tp_workers( self, @@ -132,7 +137,8 @@ def _prepare_last_sampled_token_ids_for_tp_workers( def prepare_input( self, execute_model_req: Optional[ExecuteModelRequest] = None, - ) -> Optional[Tuple[StatefulModelInput, WorkerInput]]: + ) -> Optional[Tuple[StatefulModelInput, WorkerInput, Dict[str, + torch.Tensor]]]: """ Depending on the current state of the request and multi step worker, this method may skip the normal _prepare_model_input and @@ -149,9 +155,8 @@ def prepare_input( broadcast_tensor_dict({}, src=0) return None virtual_engine = execute_model_req.virtual_engine - model_input, worker_input = self._get_driver_input_and_broadcast( - execute_model_req - ) + (model_input, worker_input, + kwargs) = self._get_driver_input_and_broadcast(execute_model_req) assert isinstance(model_input, StatefulModelInput) if execute_model_req.is_first_multi_step: # cache the worker input and model input for the next steps @@ -165,7 +170,7 @@ def prepare_input( # loop if broadcast_data is None: return None - model_input, worker_input = broadcast_data + model_input, worker_input, kwargs = broadcast_data assert isinstance(model_input, StatefulModelInput) virtual_engine = worker_input.virtual_engine if model_input.is_first_multi_step: @@ -188,4 +193,4 @@ def prepare_input( ) assert model_input is not None assert worker_input is not None - return model_input, worker_input + return model_input, worker_input, kwargs diff --git a/aphrodite/task_handler/worker.py b/aphrodite/task_handler/worker.py index 77089c7eb..d65f9435d 100644 --- a/aphrodite/task_handler/worker.py +++ b/aphrodite/task_handler/worker.py @@ -87,7 +87,7 @@ def __init__( or (speculative_config.draft_model_config.model == model_config.model) \ or (speculative_config.draft_model_config.hf_config.model_type - not in ["medusa", "mlp_speculator"]) \ + not in ["medusa", "mlp_speculator", "eagle"]) \ else {"return_hidden_states": True} ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner diff --git a/aphrodite/task_handler/worker_base.py b/aphrodite/task_handler/worker_base.py index ee2a5b05b..7dcc91c6c 100644 --- a/aphrodite/task_handler/worker_base.py +++ b/aphrodite/task_handler/worker_base.py @@ -218,7 +218,9 @@ def execute_worker(self, worker_input: WorkerInput) -> None: raise NotImplementedError def _get_worker_input_from_broadcast( - self) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]: + self + ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[ + str, torch.Tensor]]]: """ Get the worker input from the 
broadcasted tensor dict. """ assert self.do_metadata_broadcast assert not self.is_driver_worker @@ -231,11 +233,13 @@ def _get_worker_input_from_broadcast( self.model_runner.make_model_input_from_broadcasted_tensor_dict( broadcast_data)) - return model_input, worker_input + kwargs = extract_previous_hidden_states(broadcast_data) + return model_input, worker_input, kwargs + def _get_driver_input_and_broadcast( self, execute_model_req: ExecuteModelRequest - ) -> Tuple[BroadcastableModelInput, WorkerInput]: + ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]: """ Get the driver input and broadcast it to other workers. """ assert self.is_driver_worker @@ -247,17 +251,21 @@ def _get_driver_input_and_broadcast( execute_model_req.virtual_engine, execute_model_req.finished_requests_ids)) + kwargs = extract_previous_hidden_states(execute_model_req) + if self.do_metadata_broadcast: broadcast_data = worker_input.as_broadcastable_tensor_dict() broadcast_data.update(model_input.as_broadcastable_tensor_dict()) + broadcast_data.update(kwargs) broadcast_tensor_dict(broadcast_data, src=0) - return model_input, worker_input + return model_input, worker_input, kwargs def prepare_input( self, execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]: + ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[ + str, torch.Tensor]]]: """ Prepare the inputs to ModelRunner and workers. """ @@ -286,7 +294,7 @@ def execute_model( if inputs is None: return None - model_input, worker_input = inputs + model_input, worker_input, kwargs = inputs num_steps = worker_input.num_steps self.execute_worker(worker_input) # If there is no input, we don't need to execute the model. @@ -300,9 +308,13 @@ def execute_model( all_gather_group=get_tp_group())) output = self.model_runner.execute_model( - model_input, self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None, intermediate_tensors, - num_steps) + model_input=model_input, + kv_caches=self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, + intermediate_tensors=intermediate_tensors, + num_steps=num_steps, + **kwargs, + ) if not get_pp_group().is_last_rank: # output is IntermediateTensors @@ -338,9 +350,14 @@ def _execute_model_spmd( if worker_input.num_seq_groups == 0: return [] + kwargs = extract_previous_hidden_states(execute_model_req) return self.model_runner.execute_model( - model_input, self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None, intermediate_tensors) + model_input=model_input, + kv_caches=self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, + intermediate_tensors=intermediate_tensors, + **kwargs, + ) class WorkerWrapperBase: @@ -416,3 +433,20 @@ def execute_method(self, method, *args, **kwargs): "This might cause deadlock in distributed execution.") logger.exception(msg) raise e + +def extract_previous_hidden_states( + data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \ + Dict[str, torch.Tensor]: + """If data contains previous_hidden_states, extract it. This returns a dict + which can be used directly as additional kwargs in any following + execute_model calls. This is used in draft models like EAGLE.""" + output = {} + # When called from non-driver worker, data is dict but when called from + # driver worker, data is ExecuteModelRequest. 
+ if isinstance(data, dict): + if "previous_hidden_states" in data: + output["previous_hidden_states"] = data["previous_hidden_states"] + elif data.previous_hidden_states is not None: + output["previous_hidden_states"] = data.previous_hidden_states\ + .hidden_states + return output diff --git a/aphrodite/transformers_utils/config.py b/aphrodite/transformers_utils/config.py index 83ff23f09..ed27e2a5b 100644 --- a/aphrodite/transformers_utils/config.py +++ b/aphrodite/transformers_utils/config.py @@ -17,6 +17,7 @@ import aphrodite.common.envs as envs from aphrodite.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, + EAGLEConfig, InternVLChatConfig, JAISConfig, MedusaConfig, MLPSpeculatorConfig, @@ -44,6 +45,7 @@ "medusa": MedusaConfig, "internvl_chat": InternVLChatConfig, "ultravox": UltravoxConfig, + "eagle": EAGLEConfig, } for name, cls in _CONFIG_REGISTRY.items(): diff --git a/aphrodite/transformers_utils/configs/__init__.py b/aphrodite/transformers_utils/configs/__init__.py index 2b1dd1469..edc3d0c35 100644 --- a/aphrodite/transformers_utils/configs/__init__.py +++ b/aphrodite/transformers_utils/configs/__init__.py @@ -1,5 +1,6 @@ from aphrodite.transformers_utils.configs.chatglm import ChatGLMConfig from aphrodite.transformers_utils.configs.dbrx import DbrxConfig +from aphrodite.transformers_utils.configs.eagle import EAGLEConfig # RWConfig is for the original tiiuae/falcon-40b(-instruct) and # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. @@ -22,4 +23,5 @@ "MLPSpeculatorConfig", "MedusaConfig", "UltravoxConfig", + "EAGLEConfig", ] diff --git a/aphrodite/transformers_utils/configs/eagle.py b/aphrodite/transformers_utils/configs/eagle.py new file mode 100644 index 000000000..09f5b7c89 --- /dev/null +++ b/aphrodite/transformers_utils/configs/eagle.py @@ -0,0 +1,58 @@ +import os +from typing import Optional, Union + +from transformers import AutoConfig, PretrainedConfig + + +class EAGLEConfig(PretrainedConfig): + model_type = "eagle" + + def __init__( + self, + model: Union[PretrainedConfig, dict, None] = None, + truncated_vocab_size: Optional[int] = None, + **kwargs, + ): + model_config = ( + None + if model is None + else ( + AutoConfig.for_model(**model) + if isinstance(model, dict) + else model + ) + ) + for k, v in kwargs.items(): + if ( + k != "architectures" + and k != "model_type" + and hasattr(model_config, k) + ): + setattr(model_config, k, v) + self.model = model_config + if self.model is None: + self.truncated_vocab_size = None + else: + self.truncated_vocab_size = ( + self.model.vocab_size + if truncated_vocab_size is None + else truncated_vocab_size + ) + if "architectures" not in kwargs: + kwargs["architectures"] = ["EAGLEModel"] + super().__init__(**kwargs) + if self.model is not None: + for k, v in self.model.to_dict().items(): + if not hasattr(self, k): + setattr(self, k, v) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + **kwargs, + ) -> "EAGLEConfig": + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs + ) + return cls.from_dict(config_dict, **kwargs) diff --git a/pytest.ini b/pytest.ini index 7600b0470..880444d04 100644 --- a/pytest.ini +++ b/pytest.ini @@ -3,5 +3,7 @@ filterwarnings = ignore::DeprecationWarning:pkg_resources.* ignore::DeprecationWarning:pyairports.* ignore::pytest_asyncio.plugin.PytestDeprecationWarning + ignore::FutureWarning:xformers.* + 
ignore::FutureWarning:torch.* asyncio_mode = strict asyncio_default_fixture_loop_scope = function \ No newline at end of file diff --git a/tests/benchmarks/engine/throughput.py b/tests/benchmarks/engine/throughput.py index 0328f97a3..b5907bdc9 100644 --- a/tests/benchmarks/engine/throughput.py +++ b/tests/benchmarks/engine/throughput.py @@ -88,6 +88,9 @@ def run_aphrodite( load_format: str = EngineArgs.load_format, max_num_seqs: Optional[int] = None, num_scheduler_steps: Optional[int] = None, + speculative_model: Optional[str] = None, + num_speculative_tokens: Optional[int] = None, + use_v2_block_manager: bool = False, ) -> float: from aphrodite import LLM, SamplingParams llm = LLM( @@ -114,6 +117,9 @@ def run_aphrodite( load_format=load_format, max_num_seqs=max_num_seqs, num_scheduler_steps=num_scheduler_steps, + speculative_model=speculative_model, + num_speculative_tokens=num_speculative_tokens, + use_v2_block_manager=use_v2_block_manager, ) # Add the requests to the engine. @@ -242,7 +248,8 @@ def main(args: argparse.Namespace): args.enable_prefix_caching, args.enable_chunked_prefill, args.max_num_batched_tokens, args.distributed_executor_backend, args.gpu_memory_utilization, args.download_dir, args.load_format, - args.max_num_seqs, args.num_scheduler_steps) + args.max_num_seqs, args.num_scheduler_steps, args.speculative_model, + args.num_speculative_tokens, args.use_v2_block_manager) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -435,6 +442,20 @@ def main(args: argparse.Namespace): 'section for more information.\n' '* "bitsandbytes" will load the weights using bitsandbytes ' 'quantization.\n') + parser.add_argument( + '--speculative-model', + type=str, + default=None, + help='speculative model for spec decode.') + parser.add_argument( + '--num-speculative-tokens', + type=int, + default=None, + help='number of speculative tokens for spec decode.') + parser.add_argument( + '--use-v2-block-manager', + action='store_true', + help='use v2 block manager for spec decode.') args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/tests/spec_decode/e2e/test_eagle_correctness.py b/tests/spec_decode/e2e/test_eagle_correctness.py new file mode 100644 index 000000000..d3659792e --- /dev/null +++ b/tests/spec_decode/e2e/test_eagle_correctness.py @@ -0,0 +1,301 @@ +"""This docstring details important information on the testing methodology. +Most of the tests rely on "greedy equality", where we expect the output of +speculative decoding on a sequence to exactly match the output of normal non- +speculative decoding. +Since speculative decoding with rejection sampling guarantees that the output +distribution matches the target model's output distribution (up to hardware +numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy +equality. +However, we still need to verify below scenario could be passed: + * Batch size 1 greedy equality + * Batch size >1 greedy equality + * Test greedy equality under preemption + * Test greedy equality under various number of speculative tokens. +With those tests, we can say at least, EAGLE would not break the +correctess for the target model outputs. +""" +import pytest + +from .conftest import run_greedy_equality_correctness_test + +# main model +MAIN_MODEL = "JackFram/llama-68m" +# speculative model +SPEC_MODEL = "abhigoyal/vllm-eagle-llama-68m-random" +# max. 
number of speculative tokens: this corresponds to +# num_heads in the config.json of the speculator model. +MAX_SPEC_TOKENS = 4 +# precision +PRECISION = "float32" + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [ + { + # Skip cuda graph recording for fast test. + "enforce_eager": True, + # Required for spec decode. + "use_v2_block_manager": True, + # Print spec metrics. + "disable_log_stats": False, + # Precision + "dtype": PRECISION, + # Main model + "model": MAIN_MODEL, + # Get the safetensors model + "revision": "refs/pr/9", + } + ], +) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, + ], +) +@pytest.mark.parametrize( + "output_len", + [ + 128, + ], +) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_eagle_e2e_greedy_correctness( + baseline_llm_generator, test_llm_generator, batch_size: int, output_len: int +): + """Verify greedy equality with different batch size.""" + run_greedy_equality_correctness_test( + baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True, + ) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [ + { + "enforce_eager": False, + # Required for spec decode. + "use_v2_block_manager": True, + # Print spec metrics. + "disable_log_stats": False, + # Precision + "dtype": PRECISION, + # Main model + "model": MAIN_MODEL, + # Get the safetensors model + "revision": "refs/pr/9", + } + ], +) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, + ], +) +@pytest.mark.parametrize( + "output_len", + [ + 128, + ], +) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_eagle_e2e_greedy_correctness_cuda_graph( + baseline_llm_generator, test_llm_generator, batch_size: int, output_len: int +): + """Verify greedy equality with cuda graph enabled and different + batch sizes.""" + run_greedy_equality_correctness_test( + baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True, + ) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [ + { + "block_size": 8, + # 2 for small prompt, 256//8 for generated. + "num_gpu_blocks_override": 2 + 256 // 8, + "max_model_len": (2 + 256 // 8) * 8, + # Skip cuda graph recording for fast test. + "enforce_eager": True, + # Required for spec decode. + "use_v2_block_manager": True, + # Precision + "dtype": PRECISION, + # Main model + "model": MAIN_MODEL, + # Get the safetensors model + "revision": "refs/pr/9", + } + ], +) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, + ], +) +@pytest.mark.parametrize( + "output_len", + [ + # Use small output len for fast test. 
+ 128, + ], +) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_eagle_e2e_greedy_correctness_with_preemption( + baseline_llm_generator, test_llm_generator, batch_size: int, output_len: int +): + """Verify greedy equality, even when some sequences are preempted mid- + generation. + """ + run_greedy_equality_correctness_test( + baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True, + ) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [ + { + # Skip cuda graph recording for fast test. + "enforce_eager": True, + # Required for spec decode. + "use_v2_block_manager": True, + # Precision + "dtype": PRECISION, + # Main model + "model": MAIN_MODEL, + # Get the safetensors model + "revision": "refs/pr/9", + } + ], +) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": k, + } + # Try a range of num. speculative tokens + for k in range(1, 1 + MAX_SPEC_TOKENS) + ], +) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ], +) +@pytest.mark.parametrize("seed", [1]) +def test_eagle_different_k( + baseline_llm_generator, test_llm_generator, batch_size: int, output_len: int +): + """Verify that eagle speculative decoding produces exact equality + to without spec decode with different values of num_speculative_tokens. + """ + run_greedy_equality_correctness_test( + baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True, + ) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [ + { + # Skip cuda graph recording for fast test. + "enforce_eager": True, + # Required for spec decode. + "use_v2_block_manager": True, + # Precision + "dtype": PRECISION, + # Main model + "model": MAIN_MODEL, + # Get the safetensors model + "revision": "refs/pr/9", + } + ], +) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_disable_by_batch_size": 4, + } + ], +) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ], +) +@pytest.mark.parametrize("seed", [1]) +def test_eagle_disable_queue( + baseline_llm_generator, test_llm_generator, batch_size: int, output_len: int +): + """Verify that eagle speculative decoding produces exact equality + to without spec decode when speculation is disabled for large + batch sizes. 
+ """ + run_greedy_equality_correctness_test( + baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True, + ) + + +if __name__ == "__main__": + import pytest + + pytest.main([__file__]) diff --git a/tests/spec_decode/e2e/test_medusa_correctness.py b/tests/spec_decode/e2e/test_medusa_correctness.py index 7e4a6cc62..2e178818d 100644 --- a/tests/spec_decode/e2e/test_medusa_correctness.py +++ b/tests/spec_decode/e2e/test_medusa_correctness.py @@ -70,8 +70,9 @@ ]) @pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("seed", [1]) -def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, - batch_size: int, output_len: int): +def test_medusa_e2e_greedy_correctness(baseline_llm_generator, + test_llm_generator, batch_size: int, + output_len: int): """Verify greedy equality with different batch size.""" run_greedy_equality_correctness_test(baseline_llm_generator, test_llm_generator, @@ -79,6 +80,43 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, max_output_len=output_len, force_output_len=True) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "enforce_eager": False, + # Required for spec decode. + "use_v2_block_manager": True, + # Print spec metrics. + "disable_log_stats": False, + # Precision + "dtype": PRECISION, + # Main model + "model": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_medusa_e2e_greedy_correctness_cuda_graph(baseline_llm_generator, + test_llm_generator, + batch_size: int, + output_len: int): + """Verify greedy equality with cuda graph enabled and different + batch sizes.""" + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) @pytest.mark.parametrize( "common_llm_kwargs", @@ -116,10 +154,10 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, ]) @pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("seed", [1]) -def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator, - test_llm_generator, - batch_size: int, - output_len: int): +def test_medusa_e2e_greedy_correctness_with_preemption(baseline_llm_generator, + test_llm_generator, + batch_size: int, + output_len: int): """Verify greedy equality, even when some sequences are preempted mid- generation. """ @@ -165,9 +203,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator, 32, ]) @pytest.mark.parametrize("seed", [1]) -def test_mlp_different_k(baseline_llm_generator, test_llm_generator, - batch_size: int, output_len: int): - """Verify that mlp speculative decoding produces exact equality +def test_medusa_different_k(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify that medusa speculative decoding produces exact equality to without spec decode with different values of num_speculative_tokens. 
""" run_greedy_equality_correctness_test(baseline_llm_generator, @@ -208,9 +246,9 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator, 32, ]) @pytest.mark.parametrize("seed", [1]) -def test_mlp_disable_queue(baseline_llm_generator, test_llm_generator, - batch_size: int, output_len: int): - """Verify that mlp speculative decoding produces exact equality +def test_medusa_disable_queue(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify that medusa speculative decoding produces exact equality to without spec decode when speculation is disabled for large batch sizes. """ diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index b205cab64..dad084fcc 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -5,8 +5,8 @@ import pytest import torch -from aphrodite.common.sequence import (ExecuteModelRequest, Logprob, - SamplerOutput) +from aphrodite.common.sequence import (ExecuteModelRequest, HiddenStates, + Logprob, SamplerOutput, get_all_seq_ids) from aphrodite.modeling.utils import set_random_seed from aphrodite.spec_decode.draft_model_runner import TP1DraftModelRunner from aphrodite.spec_decode.multi_step_worker import MultiStepWorker @@ -18,7 +18,7 @@ patch_execute_model_with_seeds, zero_kv_cache) -@pytest.mark.parametrize('num_steps', list(range(1, 17))) +@pytest.mark.parametrize("num_steps", list(range(1, 17))) def test_assert_enough_kv_space(num_steps: int): """Test that the multi step worker checks for sufficient space in the KV cache. It should throw if it cannot run all the steps. @@ -46,7 +46,8 @@ def test_assert_enough_kv_space(num_steps: int): num_gpu_blocks, block_size, final_prompt_lens, - continuations=prev_output_tokens) + continuations=prev_output_tokens, + ) assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access worker = MagicMock() @@ -64,8 +65,9 @@ def test_assert_enough_kv_space(num_steps: int): } # Expect exception. - with pytest.raises(ValueError, - match='times but found insufficient KV space for'): + with pytest.raises( + ValueError, match="times but found insufficient KV space for" + ): assert_enough_kv_space(worker, inputs, num_steps) seq_group_metadata.block_tables = original_block_tables @@ -77,7 +79,7 @@ def test_same_output_for_single_step(): worker for num_steps=1. 
""" seed = 100 - model_name = 'JackFram/llama-68m' + model_name = "JackFram/llama-68m" block_size = 32 num_gpu_blocks = 2048 // block_size @@ -109,32 +111,32 @@ def test_same_output_for_single_step(): final_prompt_lens = [len(prompt) + num_steps for prompt in prompts] multi_step_seq_group = create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - final_prompt_lens=final_prompt_lens) + prompts, num_gpu_blocks, block_size, final_prompt_lens=final_prompt_lens + ) zero_kv_cache(multi_step_worker.cache_engine) set_random_seed(seed) actual_output, _ = multi_step_worker.sampler_output( execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=multi_step_seq_group), + seq_group_metadata_list=multi_step_seq_group + ), sample_len=num_steps, - seq_ids_with_bonus_token_in_last_step=set()) + seq_ids_with_bonus_token_in_last_step=set(), + ) assert len(actual_output) == num_steps actual_output = actual_output[0] single_step_seq_group = create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - final_prompt_lens=final_prompt_lens) + prompts, num_gpu_blocks, block_size, final_prompt_lens=final_prompt_lens + ) zero_kv_cache(worker.cache_engine) set_random_seed(seed) expected_output = worker.execute_model( execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=single_step_seq_group))[0] + seq_group_metadata_list=single_step_seq_group + ) + )[0] actual_token_ids = [ output.samples[0].output_token for output in actual_output @@ -150,8 +152,8 @@ def test_same_output_for_single_step(): assert actual_token_ids == expected_token_ids - print(f'{actual_logprobs=}') - print(f'{expected_logprobs=}') + print(f"{actual_logprobs=}") + print(f"{expected_logprobs=}") assert_logprobs_dict_allclose(actual_logprobs, expected_logprobs) @@ -162,7 +164,7 @@ def test_same_output_for_multi_step(): then runs the worker num_steps times, and compares the output. """ seed = 100 - model_name = 'JackFram/llama-68m' + model_name = "JackFram/llama-68m" block_size = 16 num_gpu_blocks = 2048 // block_size @@ -187,15 +189,17 @@ def test_same_output_for_multi_step(): num_steps = block_size + 1 random.seed(seed) - prompts = [[ - random.randint(0, 1000) for _ in range(random.randint(10, 20)) - ] for _ in range(10)] + prompts = [ + [random.randint(0, 1000) for _ in range(random.randint(10, 20))] + for _ in range(10) + ] final_prompt_lens = [len(prompt) + num_steps for prompt in prompts] rand_seeds = list(random.randint(0, 100) for _ in range(num_steps)) multi_step_worker.execute_model = patch_execute_model_with_seeds( - multi_step_worker, rand_seeds) + multi_step_worker, rand_seeds + ) worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds) continuations = [[1] for _ in prompts] @@ -204,16 +208,19 @@ def test_same_output_for_multi_step(): num_gpu_blocks, block_size, continuations=continuations, - final_prompt_lens=final_prompt_lens) + final_prompt_lens=final_prompt_lens, + ) # Run multi-step. zero_kv_cache(multi_step_worker.cache_engine) set_random_seed(seed) multi_step_output, _ = multi_step_worker.sampler_output( execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list), + seq_group_metadata_list=seq_group_metadata_list + ), sample_len=num_steps, - seq_ids_with_bonus_token_in_last_step=set()) + seq_ids_with_bonus_token_in_last_step=set(), + ) # Run single-step repeatedly. 
zero_kv_cache(worker.cache_engine) @@ -222,62 +229,73 @@ def test_same_output_for_multi_step(): set_random_seed(seed) for _ in multi_step_output: - seq_group_metadata_list = create_seq_group_metadata_from_prompts( prompts, num_gpu_blocks, block_size, continuations=continuations, - final_prompt_lens=final_prompt_lens) + final_prompt_lens=final_prompt_lens, + ) single_step_output.extend( - worker.execute_model(execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list))) + worker.execute_model( + execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list + ) + ) + ) # Append output tokens to new sequence data. for i, seq_group_output in enumerate(single_step_output[-1]): continuations[i].append(seq_group_output.samples[0].output_token) # Get token ids and logprobs for comparison. - multi_step_output_logprobs: List[List[Dict[int, - Logprob]]] = [[] - for _ in prompts] - single_step_output_logprobs: List[List[Dict[int, - Logprob]]] = [[] - for _ in prompts] + multi_step_output_logprobs: List[List[Dict[int, Logprob]]] = [ + [] for _ in prompts + ] + single_step_output_logprobs: List[List[Dict[int, Logprob]]] = [ + [] for _ in prompts + ] multi_step_output_token_ids: List[List[int]] = [[] for _ in prompts] single_step_output_token_ids: List[List[int]] = [[] for _ in prompts] for i, _ in enumerate(prompts): - for multi_step, single_step in zip(multi_step_output, - single_step_output): + for multi_step, single_step in zip( + multi_step_output, single_step_output + ): multi_step_output_token_ids[i].append( - multi_step[i].samples[0].output_token) + multi_step[i].samples[0].output_token + ) single_step_output_token_ids[i].append( - single_step[i].samples[0].output_token) + single_step[i].samples[0].output_token + ) multi_step_output_logprobs[i].append( - multi_step[i].samples[0].logprobs) + multi_step[i].samples[0].logprobs + ) single_step_output_logprobs[i].append( - single_step[i].samples[0].logprobs) + single_step[i].samples[0].logprobs + ) # Print per-sequence token ids for i, (multi_step_tokens, single_step_tokens) in enumerate( - zip(multi_step_output_token_ids, single_step_output_token_ids)): - print(f'{i=} {multi_step_tokens=}') - print(f'{i=} {single_step_tokens=}') - print(f'{i=} equal {multi_step_tokens == single_step_tokens}') + zip(multi_step_output_token_ids, single_step_output_token_ids) + ): + print(f"{i=} {multi_step_tokens=}") + print(f"{i=} {single_step_tokens=}") + print(f"{i=} equal {multi_step_tokens == single_step_tokens}") # Assert token ids are equal. for multi_step_tokens, single_step_tokens in zip( - multi_step_output_token_ids, single_step_output_token_ids): + multi_step_output_token_ids, single_step_output_token_ids + ): assert multi_step_tokens == single_step_tokens # Assert logprobs are equal. for multi_step_logprobs, single_step_logprobs in zip( - multi_step_output_logprobs, single_step_output_logprobs): - assert_logprobs_dict_allclose(multi_step_logprobs, - single_step_logprobs) + multi_step_output_logprobs, single_step_output_logprobs + ): + assert_logprobs_dict_allclose(multi_step_logprobs, single_step_logprobs) @torch.inference_mode() @@ -290,7 +308,7 @@ def test_multi_step_with_batch_expansion_correct_output(): expanded batch is then used for predicting the next tokens. 
""" seed = 100 - model_name = 'JackFram/llama-68m' + model_name = "JackFram/llama-68m" block_size = 16 num_gpu_blocks = 2048 // block_size @@ -316,7 +334,8 @@ def test_multi_step_with_batch_expansion_correct_output(): final_prompt_lens = [(num_steps + 1) for prompt in prompts] rand_seeds = list(random.randint(0, 100) for _ in range(num_steps)) multi_step_worker.execute_model = patch_execute_model_with_seeds( - multi_step_worker, rand_seeds) + multi_step_worker, rand_seeds + ) worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds) # Create the test continuations continuations = [[random.randint(0, 1000)] for _ in prompts] @@ -325,7 +344,8 @@ def test_multi_step_with_batch_expansion_correct_output(): num_gpu_blocks, block_size, continuations=continuations, - final_prompt_lens=final_prompt_lens) + final_prompt_lens=final_prompt_lens, + ) # Run single-step twice to generate 2 tokens. This # will simulate the bonus token case with the second token @@ -339,10 +359,15 @@ def test_multi_step_with_batch_expansion_correct_output(): num_gpu_blocks, block_size, continuations=continuations, - final_prompt_lens=final_prompt_lens) + final_prompt_lens=final_prompt_lens, + ) single_step_output.extend( - worker.execute_model(execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list))) + worker.execute_model( + execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list + ) + ) + ) # Append output tokens to new sequence data. for i, seq_group_output in enumerate(single_step_output[-1]): continuations[i].append(seq_group_output.samples[0].output_token) @@ -357,7 +382,8 @@ def test_multi_step_with_batch_expansion_correct_output(): num_gpu_blocks, block_size, continuations=multi_step_continuations, - final_prompt_lens=final_prompt_lens) + final_prompt_lens=final_prompt_lens, + ) # Run multi-step and verify that the third token prediction is accurate # for all sequences. @@ -365,11 +391,13 @@ def test_multi_step_with_batch_expansion_correct_output(): all_seq_ids = {i for i in range(batch_size)} multi_step_output, _ = multi_step_worker.sampler_output( execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list), + seq_group_metadata_list=seq_group_metadata_list + ), sample_len=1, - seq_ids_with_bonus_token_in_last_step=all_seq_ids) + seq_ids_with_bonus_token_in_last_step=all_seq_ids, + ) for index, output in enumerate(multi_step_output[-1].outputs): - assert (continuations[index][-1] == output.samples[0].output_token) + assert continuations[index][-1] == output.samples[0].output_token @torch.inference_mode() @@ -384,7 +412,7 @@ def test_multi_step_with_batch_expansion_incorrect_output(): the sequence ID is specified incorrectly. 
""" seed = 100 - model_name = 'JackFram/llama-68m' + model_name = "JackFram/llama-68m" block_size = 16 num_gpu_blocks = 2048 // block_size @@ -410,7 +438,8 @@ def test_multi_step_with_batch_expansion_incorrect_output(): final_prompt_lens = [(num_steps + 1) for prompt in prompts] rand_seeds = list(random.randint(0, 100) for _ in range(num_steps)) multi_step_worker.execute_model = patch_execute_model_with_seeds( - multi_step_worker, rand_seeds) + multi_step_worker, rand_seeds + ) worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds) # Create the test continuations continuations = [[random.randint(0, 1000)] for _ in prompts] @@ -419,7 +448,8 @@ def test_multi_step_with_batch_expansion_incorrect_output(): num_gpu_blocks, block_size, continuations=continuations, - final_prompt_lens=final_prompt_lens) + final_prompt_lens=final_prompt_lens, + ) # Run single-step twice to generate 2 tokens. This # will simulate the bonus token case with the second token # being the bonus token. @@ -432,10 +462,15 @@ def test_multi_step_with_batch_expansion_incorrect_output(): num_gpu_blocks, block_size, continuations=continuations, - final_prompt_lens=final_prompt_lens) + final_prompt_lens=final_prompt_lens, + ) single_step_output.extend( - worker.execute_model(execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list))) + worker.execute_model( + execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list + ) + ) + ) # Append output tokens to new sequence data. for i, seq_group_output in enumerate(single_step_output[-1]): continuations[i].append(seq_group_output.samples[0].output_token) @@ -450,7 +485,8 @@ def test_multi_step_with_batch_expansion_incorrect_output(): num_gpu_blocks, block_size, continuations=multi_step_continuations, - final_prompt_lens=final_prompt_lens) + final_prompt_lens=final_prompt_lens, + ) # Run multi-step. In this run INCORRECTLY specify that only the odd number # sequences have bonus tokens. Verify that with this setting the third token @@ -462,19 +498,21 @@ def test_multi_step_with_batch_expansion_incorrect_output(): odd_seq_ids = {i for i in range(batch_size) if i % 2 != 0} multi_step_output, _ = multi_step_worker.sampler_output( execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list), + seq_group_metadata_list=seq_group_metadata_list + ), sample_len=1, - seq_ids_with_bonus_token_in_last_step=odd_seq_ids) + seq_ids_with_bonus_token_in_last_step=odd_seq_ids, + ) num_mismatch = 0 for index, output in enumerate(multi_step_output[-1].outputs): if (index % 2) != 0: - assert (continuations[index][-1] == output.samples[0].output_token) - elif (continuations[index][-1] != output.samples[0].output_token): + assert continuations[index][-1] == output.samples[0].output_token + elif continuations[index][-1] != output.samples[0].output_token: num_mismatch += 1 # The prediction is accurate for some of the sequences even without proper # handling of the bonus tokens. Hence verify that the number of sequences # for which there is a mismatch is > 0. 
- assert (num_mismatch > 0) + assert num_mismatch > 0 @torch.inference_mode() @@ -485,7 +523,7 @@ def test_draft_proposals_full_speculation_len(): k = 10 batch_size = 32 vocab_size = 32_000 - device = 'cuda:0' + device = "cuda:0" draft_worker = MagicMock() proposer = Top1Proposer( @@ -494,32 +532,38 @@ def test_draft_proposals_full_speculation_len(): vocab_size=vocab_size, max_proposal_len=2048, ) - draft_worker.sampler_output.return_value = [ - SamplerOutput( - outputs=[], - sampled_token_probs=torch.rand(batch_size, - vocab_size, - device=device, - dtype=torch.float32), - logprobs=torch.rand(batch_size, - vocab_size, - device=device, - dtype=torch.float32), - sampled_token_ids=torch.randint(low=0, - high=vocab_size, - size=(batch_size, ), - device=device, - dtype=torch.long), - ) for _ in range(k) - ], True + draft_worker.sampler_output.return_value = ( + [ + SamplerOutput( + outputs=[], + sampled_token_probs=torch.rand( + batch_size, vocab_size, device=device, dtype=torch.float32 + ), + logprobs=torch.rand( + batch_size, vocab_size, device=device, dtype=torch.float32 + ), + sampled_token_ids=torch.randint( + low=0, + high=vocab_size, + size=(batch_size,), + device=device, + dtype=torch.long, + ), + ) + for _ in range(k) + ], + True, + ) seq_group_metadata_list, _, _ = create_batch(batch_size, k) proposals = proposer.get_spec_proposals( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k), - seq_ids_with_bonus_token_in_last_step=set()) + num_lookahead_slots=k, + ), + seq_ids_with_bonus_token_in_last_step=set(), + ) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) @@ -539,7 +583,7 @@ def test_draft_proposals_no_speculations(): k = 10 batch_size = 32 vocab_size = 32_000 - device = 'cuda:0' + device = "cuda:0" prompt_len = 10 draft_worker = MagicMock() @@ -550,15 +594,17 @@ def test_draft_proposals_no_speculations(): max_proposal_len=prompt_len + k - 1, ) - seq_group_metadata_list, _, _ = create_batch(batch_size, - k, - prompt_len=prompt_len) + seq_group_metadata_list, _, _ = create_batch( + batch_size, k, prompt_len=prompt_len + ) proposals = proposer.get_spec_proposals( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k), - seq_ids_with_bonus_token_in_last_step=set()) + num_lookahead_slots=k, + ), + seq_ids_with_bonus_token_in_last_step=set(), + ) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) @@ -578,7 +624,7 @@ def test_draft_proposals_mixed_k(): k = 10 batch_size = 32 vocab_size = 32_000 - device = 'cuda:0' + device = "cuda:0" small_prompt_len = 5 long_prompt_len = 10 @@ -587,10 +633,11 @@ def test_draft_proposals_mixed_k(): expected_num_proposal_seqs = 6 expected_num_no_proposal_seqs = batch_size - expected_num_proposal_seqs - prompt_len = [ - small_prompt_len for _ in range(expected_num_proposal_seqs - 1) - ] + [long_prompt_len - for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len] + prompt_len = ( + [small_prompt_len for _ in range(expected_num_proposal_seqs - 1)] + + [long_prompt_len for _ in range(expected_num_no_proposal_seqs)] + + [small_prompt_len] + ) draft_worker = MagicMock() proposer = Top1Proposer( @@ -600,25 +647,34 @@ def test_draft_proposals_mixed_k(): max_proposal_len=long_prompt_len + prev_output_token_len + k - 1, ) - draft_worker.sampler_output.return_value = [ - SamplerOutput( - outputs=[], - 
sampled_token_probs=torch.rand(expected_num_proposal_seqs, - vocab_size, - device=device, - dtype=torch.float32), - logprobs=torch.rand(expected_num_proposal_seqs, - vocab_size, - device=device, - dtype=torch.float32), - sampled_token_ids=torch.randint( - low=0, - high=vocab_size, - size=(expected_num_proposal_seqs, ), - device=device, - dtype=torch.long), - ) for _ in range(k) - ], True + draft_worker.sampler_output.return_value = ( + [ + SamplerOutput( + outputs=[], + sampled_token_probs=torch.rand( + expected_num_proposal_seqs, + vocab_size, + device=device, + dtype=torch.float32, + ), + logprobs=torch.rand( + expected_num_proposal_seqs, + vocab_size, + device=device, + dtype=torch.float32, + ), + sampled_token_ids=torch.randint( + low=0, + high=vocab_size, + size=(expected_num_proposal_seqs,), + device=device, + dtype=torch.long, + ), + ) + for _ in range(k) + ], + True, + ) seq_group_metadata_list, _, _ = create_batch( batch_size, @@ -630,8 +686,10 @@ def test_draft_proposals_mixed_k(): proposals = proposer.get_spec_proposals( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k), - seq_ids_with_bonus_token_in_last_step=set()) + num_lookahead_slots=k, + ), + seq_ids_with_bonus_token_in_last_step=set(), + ) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) @@ -651,7 +709,7 @@ def test_use_draft_model_runner_advance_step(): when applicable. """ seed = 100 - model_name = 'JackFram/llama-68m' + model_name = "JackFram/llama-68m" k = 5 batch_size = 32 @@ -670,7 +728,8 @@ def test_use_draft_model_runner_advance_step(): exception_secret = "artificial stop" worker.model_runner._gpu_advance_step = MagicMock() worker.model_runner._gpu_advance_step.side_effect = ValueError( - exception_secret) + exception_secret + ) seq_group_metadata_list, _, _ = create_batch(batch_size, k) @@ -678,16 +737,52 @@ def test_use_draft_model_runner_advance_step(): execute_model_req = ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k, - num_steps=1) + num_steps=1, + ) worker.execute_model(execute_model_req=execute_model_req) # Expect exception if _gpu_advance_step is called. execute_model_req = ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k, - num_steps=k) + num_steps=k, + ) with pytest.raises(ValueError, match=exception_secret): worker.execute_model(execute_model_req=execute_model_req) call_args_list = worker.model_runner._gpu_advance_step.call_args_list assert len(call_args_list) == 1 + + +@torch.inference_mode() +def test_expand_execute_model_request_sync_with_expand_hidden_states(): + """ + In this test we verify that the logic for expanding the + seq_group_metadata_list remains in sync with the expansion logic of + the HiddenStates in _expand_execute_model_request. 
+ """ + k = 5 + batch_size = 16 + seq_with_bonus_token_in_last_step = [1, 3, 8, 10, 13, 15] + seq_group_metadata_list, _, _ = create_batch(batch_size, k) + execute_model_request = ExecuteModelRequest( + seq_group_metadata_list, + previous_hidden_states=HiddenStates( + torch.arange(batch_size), + seq_group_metadata_list, + torch.arange(batch_size, 2 * batch_size), + ), + ) + ( + expanded_execute_model_request, + orig_seq_group_ids, + ) = MultiStepWorker._expand_execute_model_request( + execute_model_request, seq_with_bonus_token_in_last_step + ) + all_seq_ids = torch.tensor( + get_all_seq_ids(expanded_execute_model_request.seq_group_metadata_list) + ) + ref_expanded_hidden_states = all_seq_ids + batch_size + ref_expanded_hidden_states[orig_seq_group_ids] -= batch_size + assert (ref_expanded_hidden_states == expanded_execute_model_request. + previous_hidden_states.hidden_states).all().item()