File tree Expand file tree Collapse file tree 3 files changed +2
-18
lines changed Expand file tree Collapse file tree 3 files changed +2
-18
lines changed Original file line number Diff line number Diff line change @@ -252,7 +252,6 @@ def create_py_executor(
252252 with mem_monitor .observe_creation_stage (
253253 _ExecutorCreationStage .MODEL_ENGINE_DRAFT ):
254254 draft_spec_config = copy .copy (spec_config )
255- draft_spec_config .update_for_draft_init ()
256255 draft_pytorch_backend_config = copy .copy (pytorch_backend_config )
257256 if spec_config .load_format == "dummy" :
258257 draft_pytorch_backend_config .load_format = LoadFormat .DUMMY
Original file line number Diff line number Diff line change @@ -91,7 +91,7 @@ class Eagle3SpecMetadata(SpecMetadata):
9191
9292 def __post_init__ (self ):
9393 if self .layers_to_capture is None :
94- if self .num_layers == 1 :
94+ if self .is_draft_model or self . num_layers == 1 :
9595 self .layers_to_capture = (self .num_layers - 1 , )
9696 else :
9797 if self .num_layers <= 5 :
Original file line number Diff line number Diff line change @@ -391,11 +391,6 @@ def validate(self) -> None:
391391 Do any additional error checking here.
392392 """
393393
394- def update_for_draft_init (self ):
395- """
396- Update the config for draft model initialization.
397- """
398-
399394 @functools .cached_property
400395 def spec_dec_mode (self ):
401396 # spec_dec_mode has more functionality than the raw decoding_mode string.
@@ -450,7 +445,7 @@ def spec_dec_mode(self):
450445 return TorchSpeculativeDecodingMode .EAGLE3_ONE_MODEL
451446 return TorchSpeculativeDecodingMode .EAGLE3
452447
453- @property
448+ @functools . cached_property
454449 def num_capture_layers (self ):
455450 """
456451 Returns the number of layers to capture of the target model.
@@ -461,16 +456,6 @@ def num_capture_layers(self):
461456 return len (self .eagle3_layers_to_capture )
462457 return 3
463458
464- def update_for_draft_init (self ):
465- """
466- Update the config for draft model initialization.
467- """
468- if not self .eagle3_one_model :
469- num_layers = self .num_eagle_layers
470- if num_layers is None :
471- num_layers = 1
472- self .eagle3_layers_to_capture = set (num_layers - 1 )
473-
474459
475460class UserProvidedDecodingConfig (DecodingBaseConfig ):
476461 # Cannot use real type annotations due to circular imports
You can’t perform that action at this time.
0 commit comments