|
54 | 54 | from vllm.sequence import IntermediateTensors |
55 | 55 | from vllm.utils.tensor_schema import TensorSchema, TensorShape |
56 | 56 |
|
57 | | -from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP |
| 57 | +from .interfaces import (MultiModalEmbeddings, SupportsEagle3, |
| 58 | + SupportsMultiModal, SupportsPP) |
58 | 59 | from .llama4 import Llama4ForCausalLM |
59 | 60 | from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix |
60 | 61 | from .vision import run_dp_sharded_vision_model |
@@ -708,8 +709,8 @@ def get_dummy_mm_data( |
708 | 709 | info=Mllama4ProcessingInfo, |
709 | 710 | dummy_inputs=Mllama4DummyInputsBuilder, |
710 | 711 | ) |
711 | | -class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, |
712 | | - SupportsPP): |
| 712 | +class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, |
| 713 | + SupportsEagle3): |
713 | 714 | packed_modules_mapping = { |
714 | 715 | "qkv_proj": ["q_proj", "k_proj", "v_proj"], |
715 | 716 | "gate_up_proj": ["gate_proj", "up_proj"], |
@@ -758,6 +759,23 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
758 | 759 | self.make_empty_intermediate_tensors = ( |
759 | 760 | self.language_model.make_empty_intermediate_tensors) |
760 | 761 |
|
| 762 | + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: |
| 763 | + """Set which layers should output auxiliary hidden states for EAGLE3.""" |
| 764 | + # Delegate to underlying language model (Llama4ForCausalLM) |
| 765 | + assert hasattr(self.language_model, 'set_aux_hidden_state_layers') |
| 766 | + self.language_model.set_aux_hidden_state_layers(layers) |
| 767 | + |
| 768 | + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: |
| 769 | + """Get the layer indices for auxiliary hidden state outputs. |
| 770 | +
|
| 771 | + Note: The GPU model runner will override this with layers from |
| 772 | + the speculative config if available, providing dynamic configuration. |
| 773 | + """ |
| 774 | + # Delegate to underlying language model (Llama4ForCausalLM) |
| 775 | + assert hasattr(self.language_model, |
| 776 | + 'get_eagle3_aux_hidden_state_layers') |
| 777 | + self.language_model.get_eagle3_aux_hidden_state_layers() |
| 778 | + |
761 | 779 | def _parse_and_validate_image_input( |
762 | 780 | self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]: |
763 | 781 | # num_images, 1, num_chunks, channel, image_size, image_size |
|
0 commit comments