Skip to content

Commit 9e57506

Browse files
committed
Draft: MultiLayer Eagle
Signed-off-by: Izzy Putterman <[email protected]>
1 parent 2101d46 commit 9e57506

File tree

4 files changed

+36
-4
lines changed

4 files changed

+36
-4
lines changed

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def __init__(
150150
self.dtype = config.torch_dtype
151151
self.hidden_size = config.hidden_size
152152
self.mapping = model_config.mapping
153+
self.num_layers = model_config.pretrained_config.num_hidden_layers
153154

154155
if hasattr(config, "target_hidden_size"):
155156
self.hidden_size_in = config.target_hidden_size
@@ -163,7 +164,13 @@ def __init__(
163164
bias=getattr(config, "bias", False),
164165
dtype=config.torch_dtype)
165166

166-
self.midlayer = Eagle3DecoderLayer(model_config, start_layer_idx)
167+
if self.num_layers > 1:
168+
self.midlayer = nn.ModuleList([
169+
Eagle3DecoderLayer(model_config, start_layer_idx + i)
170+
for i in range(self.num_layers)
171+
])
172+
else:
173+
self.midlayer = Eagle3DecoderLayer(model_config, start_layer_idx)
167174

168175
self.norm = RMSNorm(hidden_size=config.hidden_size,
169176
eps=config.rms_norm_eps,
@@ -212,11 +219,19 @@ def forward(
212219
# we expect that to happen outside the model definition. This helps us
213220
# avoid data-dependent control flow and gives us better CUDA graph
214221
# coverage.
215-
hidden_states, residual = self.midlayer(position_ids=position_ids,
222+
if self.num_layers > 1:
223+
for layer in self.midlayer:
224+
hidden_states, residual = layer(position_ids=position_ids,
216225
embeds=inputs_embeds,
217226
hidden_states=hidden_states,
218227
attn_metadata=attn_metadata,
219228
spec_metadata=spec_metadata)
229+
else:
230+
hidden_states, residual = self.midlayer(position_ids=position_ids,
231+
embeds=inputs_embeds,
232+
hidden_states=hidden_states,
233+
attn_metadata=attn_metadata,
234+
spec_metadata=spec_metadata)
220235

221236
hidden_states, hidden_states_to_save = self.norm(
222237
hidden_states, residual)

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ def create_py_executor(
252252
with mem_monitor.observe_creation_stage(
253253
_ExecutorCreationStage.MODEL_ENGINE_DRAFT):
254254
draft_spec_config = copy.copy(spec_config)
255+
draft_spec_config.update_for_draft_init()
255256
draft_pytorch_backend_config = copy.copy(pytorch_backend_config)
256257
if spec_config.load_format == "dummy":
257258
draft_pytorch_backend_config.load_format = LoadFormat.DUMMY

tensorrt_llm/_torch/speculative/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,8 @@ def get_num_spec_layers(spec_config):
151151
if spec_config.spec_dec_mode.is_mtp():
152152
return spec_config.num_nextn_predict_layers
153153
if spec_config.spec_dec_mode.is_eagle3_one_model():
154-
return 1
154+
num_eagle_layers = spec_config.num_eagle_layers
155+
return num_eagle_layers if num_eagle_layers is not None else 1
155156
return 0
156157

157158

tensorrt_llm/llmapi/llm_args.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,11 @@ def validate(self) -> None:
391391
Do any additional error checking here.
392392
"""
393393

394+
def update_for_draft_init(self):
395+
"""
396+
Update the config for draft model initialization.
397+
"""
398+
394399
@functools.cached_property
395400
def spec_dec_mode(self):
396401
# spec_dec_mode has more functionality than the raw decoding_mode string.
@@ -445,7 +450,7 @@ def spec_dec_mode(self):
445450
return TorchSpeculativeDecodingMode.EAGLE3_ONE_MODEL
446451
return TorchSpeculativeDecodingMode.EAGLE3
447452

448-
@functools.cached_property
453+
@property
449454
def num_capture_layers(self):
450455
"""
451456
Returns the number of layers to capture of the target model.
@@ -456,6 +461,16 @@ def num_capture_layers(self):
456461
return len(self.eagle3_layers_to_capture)
457462
return 3
458463

464+
def update_for_draft_init(self):
465+
"""
466+
Update the config for draft model initialization.
467+
"""
468+
if not self.eagle3_one_model:
469+
num_layers = self.num_eagle_layers
470+
if num_layers is None:
471+
num_layers = 1
472+
self.eagle3_layers_to_capture = set(num_layers - 1)
473+
459474

460475
class UserProvidedDecodingConfig(DecodingBaseConfig):
461476
# Cannot use real type annotations due to circular imports

0 commit comments

Comments
 (0)