spec decode: add support for EAGLE
AlpinDale committed Dec 16, 2024
1 parent bfc3da4 commit 9d12c23
Showing 20 changed files with 1,085 additions and 218 deletions.
4 changes: 3 additions & 1 deletion aphrodite/common/config.py
@@ -521,12 +521,14 @@ def get_hidden_size(self) -> int:

def get_head_size(self) -> int:
# TODO remove hard code
spec_model_types = ["medusa", "mlp_speculator"]
if hasattr(self.hf_text_config, "model_type"
) and self.hf_text_config.model_type == 'deepseek_v2':
# FlashAttention supports only head_size 32, 64, 128, 256,
# we need to pad head_size 192 to 256
return 256
if self.is_attention_free():
if self.is_attention_free() or \
self.hf_text_config.model_type in spec_model_types:
return 0
if hasattr(self.hf_text_config, "head_dim"):
return self.hf_text_config.head_dim
77 changes: 62 additions & 15 deletions aphrodite/common/sequence.py
@@ -1116,6 +1116,13 @@ class SamplerOutput(
# Optional last hidden states from the model.
hidden_states: Optional[torch.Tensor] = None

# Optional prefill hidden states from the model
# (used for models like EAGLE).
prefill_hidden_states: Optional[torch.Tensor] = None

# Time taken in the forward pass for this across all workers
model_forward_time: Optional[float] = None

def __getitem__(self, idx: int):
return self.outputs[idx]

@@ -1189,47 +1196,87 @@ def get_all_seq_ids_and_request_ids(
return seq_ids, request_id_seq_ids_mapping


class HiddenStates(
msgspec.Struct,
omit_defaults=True,
array_like=True
):
class HiddenStates(msgspec.Struct, array_like=True,
omit_defaults=True): # type: ignore[call-arg]
"""Hidden states corresponding to in-progress sequences.
Used in speculative decoding to pass hidden states from
the target model to the proposer model in the subsequent step.
the target model to the proposer model.
seq_ids are the sequence ids of each entry of the batch
dimension of the hidden_states tensor"""

seq_group_metadata_list: List[SequenceGroupMetadata]
# Scorer hidden states. For prefill step, it is used for hidden states of
# all tokens, whereas for decode step, it is used for last accepted tokens.
hidden_states: torch.Tensor
# The sequence group metadata list. Only needed for decode step.
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
# Scorer hidden states of the 2nd last token proposed by the proposer (
# irrespective of whether it was accepted or not). Only used for cases when
# last proposed token is accepted (i.e., in case of bonus tokens). For the
# case of no bonus tokens, these are ignored.
second_last_token_hidden_states: Optional[torch.Tensor] = None
_seq_ids: List[int] = msgspec.field(default_factory=list)

def __post_init__(self):
self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
assert len(self.seq_group_metadata_list) == len(self.hidden_states)
if self.seq_group_metadata_list is not None:
assert len(self.seq_group_metadata_list) == len(self.hidden_states)
self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

@property
def seq_ids(self) -> List[int]:
return self._seq_ids

def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
hidden_states: torch.Tensor) -> None:
"""Update hidden states from target model invocation."""
def update(self,
hidden_states: torch.Tensor,
seq_group_metadata_list: List[SequenceGroupMetadata],
second_last_token_hidden_states: Optional[torch.Tensor] = None):
"""Update hidden states from target model invocation. Only used for
decode steps"""
assert len(seq_group_metadata_list) == len(hidden_states)
self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
self.hidden_states = torch.cat([self.hidden_states, hidden_states])

if self.second_last_token_hidden_states is not None:
# Adding dummy hidden_states to this to maintain same shape
self.second_last_token_hidden_states = torch.cat([
self.second_last_token_hidden_states,
torch.zeros_like(hidden_states)
if second_last_token_hidden_states is None else
second_last_token_hidden_states
])

def prune(self,
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
"""Prune to provided list of sequence ids."""
"""Prune to provided list of sequence ids. Only used for decode steps.
"""
# Currently this prunes all seq_ids not present in
# seq_group_metadata_list which might cause problems where a sequence
# may be "paused" then "resumed" later. This should only prune sequences
# which are confirmed to be aborted.
seq_ids = get_all_seq_ids(seq_group_metadata_list)
if seq_ids != self._seq_ids:
# Batch contents changed - prune removed sequences.
index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
self.hidden_states = self.hidden_states[index]
if self.second_last_token_hidden_states is not None:
self.second_last_token_hidden_states = self\
.second_last_token_hidden_states[index]
self._seq_ids = seq_ids

def expand_with_bonus_tokens(
self, seq_with_bonus_token_in_last_step: set) -> None:
"""Expand hidden states for sequences with bonus tokens. This is in
alignment with `MultiStepWorker._expand_execute_model_request`."""
if self.second_last_token_hidden_states is None \
or not seq_with_bonus_token_in_last_step:
return
index = []
for seq_id in self._seq_ids:
i = self._seq_ids.index(seq_id)
if seq_id in seq_with_bonus_token_in_last_step:
index.append(i + len(self._seq_ids))
index.append(i)
self.hidden_states = torch.cat(
[self.hidden_states, self.second_last_token_hidden_states])[index]


class ExecuteModelRequest(
msgspec.Struct,
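The `HiddenStates` container above carries target-model hidden states between the scoring and drafting steps. A minimal lifecycle sketch follows; the tensors and metadata list are hypothetical placeholders (the real objects come from the scheduler, not from this commit):

from aphrodite.common.sequence import HiddenStates

# Hypothetical placeholders: hidden states of shape [batch, hidden_size]
# produced by the target model, plus the matching metadata for the batch.
prefill_hidden = decode_hidden = ...   # torch.Tensor
decode_metadata = ...                  # List[SequenceGroupMetadata]

# After prefill, only the hidden states are stored.
hidden = HiddenStates(hidden_states=prefill_hidden)

# After each decode step, append the hidden states of the newly accepted
# tokens along with the metadata that identifies their sequences.
hidden.update(hidden_states=decode_hidden,
              seq_group_metadata_list=decode_metadata)

# Before the next proposal, drop entries for sequences that left the batch.
hidden.prune(decode_metadata)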
1 change: 1 addition & 0 deletions aphrodite/modeling/models/__init__.py
@@ -61,6 +61,7 @@
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
"MedusaModel": ("medusa", "Medusa"),
"EAGLEModel": ("eagle", "EAGLE"),
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
"ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
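With the registry entry in place, a draft checkpoint whose config lists `EAGLEModel` under `architectures` resolves to the `EAGLE` class added below. A small sketch of the lookup, using the same `resolve_model_cls` helper that eagle.py calls in its constructor:

from aphrodite.modeling.models import ModelRegistry

# Resolve the architecture string to its implementing class; the second
# element of the returned tuple is the matched architecture name.
model_cls, matched_arch = ModelRegistry.resolve_model_cls(["EAGLEModel"])
print(model_cls.__name__)  # EAGLE, defined in aphrodite/modeling/models/eagle.py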
155 changes: 155 additions & 0 deletions aphrodite/modeling/models/eagle.py
@@ -0,0 +1,155 @@
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn

from aphrodite.attention.backends.abstract import AttentionMetadata
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models import ModelRegistry
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.transformers_utils.configs.eagle import EAGLEConfig


class EAGLE(nn.Module):
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
Reference implementation: https://github.com/SafeAILab/EAGLE
Differences from reference implementation:
1. In reference, LlamaDecoderLayer implementation doesn't have
input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427)
but we do as HF implementation also does.
2. We allow any decoder layer to be used in EAGLE whereas in reference
decoder layer is fixed to be LlamaDecoderLayer.
3. We have an optional token_map which reduces draft vocab to most
frequently used tokens to give some additional speed-up by reducing
sampling overhead. This is disabled unless the checkpoint file has
explicit token_map tensor and config has an optional attribute
truncated_vocab_size < vocab_size. To use this technique, one has to find
the top-k most frequent tokens in target dataset and add that as a tensor
in the draft checkpoint (using key token_map). Also, the draft config
needs to have truncated_vocab_size (=k) as an attribute."""

def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None:
super().__init__()
self.config = config
architectures = getattr(self.config.model, "architectures", [])
model_cls, _ = ModelRegistry.resolve_model_cls(architectures)
self.model = model_cls(self.config.model, *args, **kwargs)
self.fc = nn.Linear(
config.model.hidden_size * 2, config.model.hidden_size, bias=False
)
self.orig_vocab_size = config.vocab_size
self.truncated_vocab_size = config.truncated_vocab_size
self.unpadded_vocab_size = self.truncated_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=self.truncated_vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, self.truncated_vocab_size, logit_scale
)
# Token map is an idx-to-token mapping to reduce the vocab size for
# the draft model. Using smaller vocab size for draft, containing
# only most frequent tokens reduces the speculation overhead. This
# doesn't affect the acceptance rate much and thus gives more speed
# -up. By default, this is disabled and is only used if the EAGLE
# checkpoint file has token_map tensor.
self.token_map = None

@property
def sampler(self):
return self.model.sampler

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
tok_embeds = self.model.model.embed_tokens(input_ids)
inputs_embeds = self.fc(
torch.cat([tok_embeds, previous_hidden_states], dim=-1)
)
inputs_embeds[positions == 0] = 0 # masking inputs at position=0
hidden_states = self.model.model(
input_ids=None,
inputs_embeds=inputs_embeds,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
)
return hidden_states

def compute_logits(
self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata
) -> torch.Tensor:
logits = self.logits_processor(
self.lm_head, hidden_states, sampling_metadata
)
if self.token_map is not None:
_logits = logits
logits = -torch.inf * torch.ones(
size=(*_logits.shape[:-1], self.orig_vocab_size),
device=_logits.device,
dtype=_logits.dtype,
)
logits[..., self.token_map] = _logits
return logits

def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# This implementation is incompatible with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B
# due to missing lm_head weights and its config being that of a
# Llama model. Here's a compatible version with the same weights:
# https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm
# Also, here's an example script for converting trained EAGLE
# checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d
model_weights = {}
for name, loaded_weight in weights:
if name == "token_map":
if self.config.truncated_vocab_size < self.config.vocab_size:
self.token_map = nn.Parameter(
loaded_weight, requires_grad=False
)
elif name.startswith("fc."):
weight_loader = getattr(
self.fc.weight, "weight_loader", default_weight_loader
)
weight_loader(self.fc.weight, loaded_weight)
elif name.startswith("model.lm_head.") or name.startswith(
"model.model."
):
model_weights[name.split("model.", 1)[-1]] = loaded_weight
elif name.startswith("lm_head.") or name.startswith("model."):
model_weights[name] = loaded_weight
else:
model_weights[f"model.{name}"] = loaded_weight
lm_head_weight = model_weights.pop("lm_head.weight")
if (
self.token_map is not None
and lm_head_weight.shape[0] > self.token_map.shape[0]
):
lm_head_weight = lm_head_weight[self.token_map]
weight_loader = getattr(
self.lm_head.weight, "weight_loader", default_weight_loader
)
weight_loader(self.lm_head.weight, lm_head_weight)
self.model.load_weights(model_weights.items())
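The forward pass above conditions each draft token on the target model's hidden state: token embeddings and `previous_hidden_states` are concatenated along the feature dimension, projected back to `hidden_size` by the `fc` layer, and zeroed at position 0, where no previous hidden state exists. A standalone toy of that projection with made-up sizes (illustrative only, not part of the commit):

import torch
import torch.nn as nn

hidden_size, num_tokens = 8, 4
fc = nn.Linear(hidden_size * 2, hidden_size, bias=False)

tok_embeds = torch.randn(num_tokens, hidden_size)    # embed_tokens(input_ids)
prev_hidden = torch.randn(num_tokens, hidden_size)   # target-model hidden states
positions = torch.tensor([0, 1, 2, 3])

inputs_embeds = fc(torch.cat([tok_embeds, prev_hidden], dim=-1))
inputs_embeds[positions == 0] = 0  # position 0 has no previous hidden state
print(inputs_embeds.shape)         # torch.Size([4, 8])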
19 changes: 19 additions & 0 deletions aphrodite/modeling/models/medusa.py
@@ -30,6 +30,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class Medusa(nn.Module):
"""This class implements the Medusa draft model from the paper: https://arxiv.org/abs/2401.10774
Reference implementation: https://github.com/FasterDecoding/Medusa
Differences from reference implementation:
1. Currently this only supports generating proposals from top-1 tokens.
2. We have an optional token_map which reduces draft vocab to most
frequently used tokens to give some additional speed-up by reducing
sampling overhead. This is disabled unless the checkpoint file has
explicit token_map tensor and config has an optional attribute
truncated_vocab_size < vocab_size. To use this technique, one has to find
the top-k most frequent tokens in target dataset and add that as a tensor
in the draft checkpoint (using key token_map). Also, the draft config
needs to have truncated_vocab_size (=k) as an attribute."""

def __init__(self, config: MedusaConfig, **_) -> None:
super().__init__()
@@ -57,6 +70,12 @@ def __init__(self, config: MedusaConfig, **_) -> None:
self.truncated_vocab_size,
logit_scale)

# Token map is an idx-to-token mapping to reduce the vocab size for
# the draft model. Using smaller vocab size for draft, containing
# only most frequent tokens reduces the speculation overhead. This
# doesn't affect the acceptance rate much and thus gives more speed
# -up. By default, this is disabled and is only used if the Medusa
# checkpoint file has token_map tensor.
self.token_map = None

def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
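Both docstrings above describe the same optional truncated-vocabulary trick: store the top-k most frequent target-vocabulary token ids in the draft checkpoint under the key `token_map`, set `truncated_vocab_size = k` in the draft config, and let `compute_logits` scatter the small-vocab logits back into the full vocabulary. A hedged sketch of preparing such a map offline; the frequency counts, sizes, and file names below are assumptions for illustration, not part of this commit:

import torch

# Assumed input: token-id frequency counts over a representative target
# dataset, shape [vocab_size]; random counts stand in for real statistics.
vocab_size, k = 32000, 8192
token_counts = torch.randint(0, 1000, (vocab_size,))

# The top-k most frequent token ids become the draft's reduced vocabulary.
token_map = torch.topk(token_counts, k).indices.sort().values

# Store the map in the draft checkpoint (hypothetical file names).
state_dict = torch.load("draft_checkpoint.bin", map_location="cpu")
state_dict["token_map"] = token_map
torch.save(state_dict, "draft_checkpoint_with_token_map.bin")

# At inference, logits over the k draft tokens are scattered back into the
# full vocabulary, mirroring compute_logits in eagle.py:
draft_logits = torch.randn(1, k)
full_logits = torch.full((1, vocab_size), float("-inf"))
full_logits[..., token_map] = draft_logits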
18 changes: 18 additions & 0 deletions aphrodite/spec_decode/draft_model_runner.py
@@ -201,6 +201,7 @@ def execute_model(
self,
model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
previous_hidden_states: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
@@ -278,13 +279,29 @@ def execute_model(
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = (self.graph_runners[model_input.virtual_engine]
[graph_batch_size])
if previous_hidden_states is not None:
hidden_states = torch.cat([
previous_hidden_states,
torch.empty([
graph_batch_size - previous_hidden_states.shape[0],
*previous_hidden_states.shape[1:]
],
dtype=previous_hidden_states.dtype,
device=previous_hidden_states.device)
])
else:
hidden_states = None
else:
model_executable = self.model
hidden_states = previous_hidden_states

outputs: List[SamplerOutput] = []
for step in range(num_steps):
multi_modal_kwargs = model_input.multi_modal_kwargs or {}

kwargs = {"previous_hidden_states": hidden_states} \
if previous_hidden_states is not None else {}

# Run model
hidden_states = model_executable(
input_ids=model_input.input_tokens,
@@ -294,6 +311,7 @@ def execute_model(
intermediate_tensors=intermediate_tensors,
**MultiModalInputs.as_kwargs(multi_modal_kwargs,
device=self.device),
**kwargs,
)

# Compute the logits.
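The runner change above has to reconcile EAGLE's extra `previous_hidden_states` input with CUDA graphs, which are captured at fixed batch sizes: the hidden states are padded with uninitialised rows up to the captured graph batch size before the graph is replayed, and only the rows belonging to live sequences are read back. A minimal standalone sketch of that padding (the function name is illustrative, not the runner's API):

import torch

def pad_hidden_states_for_graph(prev_hidden: torch.Tensor,
                                graph_batch_size: int) -> torch.Tensor:
    """Pad [batch, hidden_size] hidden states up to the graph batch size."""
    pad_rows = graph_batch_size - prev_hidden.shape[0]
    # The padded rows are uninitialised; outputs for them are discarded.
    pad = torch.empty(pad_rows, *prev_hidden.shape[1:],
                      dtype=prev_hidden.dtype, device=prev_hidden.device)
    return torch.cat([prev_hidden, pad])

padded = pad_hidden_states_for_graph(torch.randn(3, 16), graph_batch_size=8)
print(padded.shape)  # torch.Size([8, 16])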