Skip to content

Commit ee12005

Browse files
committed
Revert unnecessary changes
Signed-off-by: Izzy Putterman <[email protected]>
1 parent 9e57506 commit ee12005

File tree

3 files changed

+2
-18
lines changed

3 files changed

+2
-18
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,6 @@ def create_py_executor(
252252
with mem_monitor.observe_creation_stage(
253253
_ExecutorCreationStage.MODEL_ENGINE_DRAFT):
254254
draft_spec_config = copy.copy(spec_config)
255-
draft_spec_config.update_for_draft_init()
256255
draft_pytorch_backend_config = copy.copy(pytorch_backend_config)
257256
if spec_config.load_format == "dummy":
258257
draft_pytorch_backend_config.load_format = LoadFormat.DUMMY

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class Eagle3SpecMetadata(SpecMetadata):
9191

9292
def __post_init__(self):
9393
if self.layers_to_capture is None:
94-
if self.num_layers == 1:
94+
if self.is_draft_model or self.num_layers == 1:
9595
self.layers_to_capture = (self.num_layers - 1, )
9696
else:
9797
if self.num_layers <= 5:

tensorrt_llm/llmapi/llm_args.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -391,11 +391,6 @@ def validate(self) -> None:
391391
Do any additional error checking here.
392392
"""
393393

394-
def update_for_draft_init(self):
395-
"""
396-
Update the config for draft model initialization.
397-
"""
398-
399394
@functools.cached_property
400395
def spec_dec_mode(self):
401396
# spec_dec_mode has more functionality than the raw decoding_mode string.
@@ -450,7 +445,7 @@ def spec_dec_mode(self):
450445
return TorchSpeculativeDecodingMode.EAGLE3_ONE_MODEL
451446
return TorchSpeculativeDecodingMode.EAGLE3
452447

453-
@property
448+
@functools.cached_property
454449
def num_capture_layers(self):
455450
"""
456451
Returns the number of layers to capture of the target model.
@@ -461,16 +456,6 @@ def num_capture_layers(self):
461456
return len(self.eagle3_layers_to_capture)
462457
return 3
463458

464-
def update_for_draft_init(self):
465-
"""
466-
Update the config for draft model initialization.
467-
"""
468-
if not self.eagle3_one_model:
469-
num_layers = self.num_eagle_layers
470-
if num_layers is None:
471-
num_layers = 1
472-
self.eagle3_layers_to_capture = set(num_layers - 1)
473-
474459

475460
class UserProvidedDecodingConfig(DecodingBaseConfig):
476461
# Cannot use real type annotations due to circular imports

0 commit comments

Comments
 (0)