@@ -2753,14 +2753,14 @@ def load_model(self, eep_scale_up: bool = False) -> None:
27532753
27542754 # Try to get auxiliary layers from speculative config,
27552755 # otherwise use model's default layers
2756- aux_layers = (self ._get_eagle3_aux_layers_from_config () or
2757- self .model .get_eagle3_aux_hidden_state_layers ())
2758-
2759- if aux_layers != self .model .get_eagle3_aux_hidden_state_layers (
2760- ):
2756+ aux_layers = self ._get_eagle3_aux_layers_from_config ()
2757+ if aux_layers :
27612758 logger .info (
27622759 "Using auxiliary layers from speculative config: %s" ,
27632760 aux_layers )
2761+ else :
2762+ aux_layers = self .model .get_eagle3_aux_hidden_state_layers (
2763+ )
27642764
27652765 self .model .set_aux_hidden_state_layers (aux_layers )
27662766 time_after_load = time .perf_counter ()
@@ -2814,7 +2814,11 @@ def load_model(self, eep_scale_up: bool = False) -> None:
28142814 CUDAGraphMode .NONE , self .device )
28152815
28162816 def _get_eagle3_aux_layers_from_config (self ) -> Optional [tuple [int , ...]]:
2817- """Extract Eagle3 auxiliary layer IDs from speculative config.
2817+ """Extract Eagle3 auxiliary layer indices from speculative config.
2818+
2819+ These indices specify which hidden states from the base model should
2820+ be used as auxiliary inputs for the Eagle3 drafter model during
2821+ speculative decoding.
28182822
28192823 Returns:
28202824 Tuple of layer indices if found in draft model config,
@@ -2824,18 +2828,13 @@ def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]:
28242828 and self .speculative_config .draft_model_config ):
28252829 return None
28262830
2827- try :
2828- hf_config = self .speculative_config .draft_model_config .hf_config
2829- if not hasattr (hf_config , 'eagle_aux_hidden_state_layer_ids' ):
2830- return None
2831-
2832- layer_ids = hf_config .eagle_aux_hidden_state_layer_ids
2833- if layer_ids and isinstance (layer_ids , (list , tuple )):
2834- return tuple (layer_ids )
2835- except Exception as e :
2836- logger .warning (
2837- "Failed to read auxiliary layers from speculative config: %s" ,
2838- e )
2831+ hf_config = self .speculative_config .draft_model_config .hf_config
2832+ if not hasattr (hf_config , 'eagle_aux_hidden_state_layer_ids' ):
2833+ return None
2834+
2835+ layer_ids = hf_config .eagle_aux_hidden_state_layer_ids
2836+ if layer_ids and isinstance (layer_ids , (list , tuple )):
2837+ return tuple (layer_ids )
28392838
28402839 return None
28412840
0 commit comments