Skip to content

Commit 2fe9cc0

Browse files
authored
chore: remove draft_model_engine from init parameter list of PyExecutor (#6325)
Signed-off-by: junq <[email protected]>
1 parent 1f39a11 commit 2fe9cc0

File tree

3 files changed

+2
-8
lines changed

3 files changed

+2
-8
lines changed

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,6 @@ def create_py_executor_instance(
411411
executor_config,
412412
ctx_chunk_config,
413413
model_engine,
414-
draft_model_engine,
415414
start_worker,
416415
sampler,
417416
drafter,
@@ -551,7 +550,6 @@ def create_py_executor_instance(
551550
max_draft_len=spec_config.max_draft_len
552551
if spec_config is not None else 0,
553552
kv_cache_transceiver=kv_cache_transceiver,
554-
draft_model_engine=draft_model_engine,
555553
guided_decoder=guided_decoder,
556554
start_worker=start_worker,
557555
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ def __init__(self,
140140
max_beam_width: int = 1,
141141
max_draft_len: int = 0,
142142
kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
143-
draft_model_engine: Optional[ModelEngine] = None,
144143
guided_decoder: Optional[GuidedDecoder] = None,
145144
garbage_collection_gen0_threshold: Optional[int] = None,
146145
start_worker: bool = True):
@@ -161,13 +160,12 @@ def __init__(self,
161160
self.enable_attention_dp = model_engine.enable_attention_dp
162161
self.sampler = sampler
163162
self.drafter = drafter
163+
self.draft_model_engine = getattr(self.drafter, "draft_model_engine",
164+
None)
164165
self.guided_decoder = guided_decoder
165166
self.dist = dist
166167
self.disable_overlap_scheduler = disable_overlap_scheduler
167168

168-
# Draft model for certain spec decode algorithms, e.g. EAGLE3
169-
self.draft_model_engine = draft_model_engine
170-
171169
# enqueue and _fetch_new_requests used data
172170
self.next_req_id = max_batch_size # The first max_batch_size request IDs are reserved for dummy requests
173171
self.max_beam_width = max_beam_width

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,6 @@ def create_py_executor(
382382
executor_config=executor_config,
383383
ctx_chunk_config=ctx_chunk_config,
384384
model_engine=model_engine,
385-
draft_model_engine=draft_model_engine,
386385
start_worker=False,
387386
sampler=sampler,
388387
drafter=drafter,
@@ -425,7 +424,6 @@ def create_py_executor(
425424
executor_config=executor_config,
426425
ctx_chunk_config=ctx_chunk_config,
427426
model_engine=model_engine,
428-
draft_model_engine=draft_model_engine,
429427
start_worker=False,
430428
sampler=sampler,
431429
drafter=drafter,

0 commit comments

Comments (0)