
Commit b8ade4c

IzzyPutterman authored and Wong4j committed
[None][feat] MultiLayer Eagle (NVIDIA#7234)
Signed-off-by: Izzy Putterman <[email protected]>
1 parent 33fa41b commit b8ade4c

File tree

5 files changed: +127 -4 lines changed

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 20 additions & 2 deletions
@@ -149,6 +149,7 @@ def __init__(
         self.dtype = config.torch_dtype
         self.hidden_size = config.hidden_size
         self.mapping = model_config.mapping
+        self.num_layers = model_config.pretrained_config.num_hidden_layers

         if hasattr(config, "target_hidden_size"):
             self.hidden_size_in = config.target_hidden_size
@@ -162,7 +163,13 @@
                          bias=getattr(config, "bias", False),
                          dtype=config.torch_dtype)

-        self.midlayer = Eagle3DecoderLayer(model_config, start_layer_idx)
+        if self.num_layers > 1:
+            self.midlayer = nn.ModuleList([
+                Eagle3DecoderLayer(model_config, start_layer_idx + i)
+                for i in range(self.num_layers)
+            ])
+        else:
+            self.midlayer = Eagle3DecoderLayer(model_config, start_layer_idx)

         self.norm = RMSNorm(hidden_size=config.hidden_size,
                             eps=config.rms_norm_eps,
@@ -211,11 +218,22 @@ def forward(
         # we expect that to happen outside the model definition. This helps us
         # avoid data-dependent control flow and gives us better CUDA graph
         # coverage.
-        hidden_states, residual = self.midlayer(position_ids=position_ids,
+        residual = None
+        if self.num_layers > 1:
+            for layer in self.midlayer:
+                if residual is not None:
+                    hidden_states = hidden_states + residual
+                hidden_states, residual = layer(position_ids=position_ids,
                                                 embeds=inputs_embeds,
                                                 hidden_states=hidden_states,
                                                 attn_metadata=attn_metadata,
                                                 spec_metadata=spec_metadata)
+        else:
+            hidden_states, residual = self.midlayer(position_ids=position_ids,
+                                                    embeds=inputs_embeds,
+                                                    hidden_states=hidden_states,
+                                                    attn_metadata=attn_metadata,
+                                                    spec_metadata=spec_metadata)

         hidden_states, hidden_states_to_save = self.norm(
             hidden_states, residual)
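The heart of this change is the residual handoff between stacked draft layers: each Eagle3DecoderLayer returns a (hidden_states, residual) pair, and the loop folds the previous residual back into the hidden states before calling the next layer, which is why residual must start as None on the first iteration. Below is a minimal, runnable sketch of that pattern; ToyDecoderLayer is a hypothetical stand-in for Eagle3DecoderLayer, which additionally takes position ids, input embeddings, and attention/spec metadata.

# A minimal sketch of the residual-carry loop, assuming a toy layer that
# mimics Eagle3DecoderLayer's (hidden_states, residual) return contract.
import torch
from torch import nn


class ToyDecoderLayer(nn.Module):
    """Hypothetical stand-in for Eagle3DecoderLayer."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor):
        # Return the transformed states plus a residual stream.
        return self.proj(hidden_states), hidden_states


layers = nn.ModuleList(ToyDecoderLayer(16) for _ in range(3))
hidden_states = torch.randn(2, 16)
residual = None
for layer in layers:
    # Fold the previous layer's residual back in before the next layer,
    # exactly as the new forward() does when num_layers > 1.
    if residual is not None:
        hidden_states = hidden_states + residual
    hidden_states, residual = layer(hidden_states)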

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 3 additions & 0 deletions
@@ -1038,6 +1038,9 @@ def init_meta_tensor(t: torch.Tensor):

         elif load_format == LoadFormat.DUMMY:
             initialize_dummy_weights(model)
+            if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights(
+            ):
+                model.draft_model.load_weights_from_target_model(model)

         elif load_format == LoadFormat.VISION_ONLY:
             # Vision weights are already loaded within the model.
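A note on why this is needed: with LoadFormat.DUMMY the checkpoint is never read, but one-model Eagle3 drafts still share tensors with the target, so that sharing must be wired up explicitly. Here is a runnable toy version of the added guard; only load_weights_from_target_model is a real name from the diff, the rest is hypothetical scaffolding.

# Toy illustration of the guard added to the DUMMY load path. DraftStub and
# finish_dummy_load are hypothetical; load_weights_from_target_model is the
# real method exercised by the diff.
class DraftStub:

    def __init__(self):
        self.wired = False

    def load_weights_from_target_model(self, target_model):
        # The real method aliases tensors shared with the target model
        # (e.g. embeddings and lm_head) into the draft model.
        self.wired = True


def finish_dummy_load(model, draft_model, need_draft_weights: bool):
    # At this point initialize_dummy_weights(model) has already run; the
    # shared-tensor wiring is still required for one-model Eagle3.
    if need_draft_weights:
        draft_model.load_weights_from_target_model(model)


draft = DraftStub()
finish_dummy_load(model=None, draft_model=draft, need_draft_weights=True)
assert draft.wired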

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ class Eagle3SpecMetadata(SpecMetadata):

     def __post_init__(self):
         if self.layers_to_capture is None:
-            if self.num_layers == 1:
+            if self.is_draft_model or self.num_layers == 1:
                 self.layers_to_capture = (self.num_layers - 1, )
             else:
                 if self.num_layers <= 5:
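This one-line change matters for multi-layer drafts: without the is_draft_model check, a two-layer draft would fall into the multi-capture branch meant for target models. A hedged, runnable sketch of the default rule, with the target-model branch elided since its exact indices are outside this hunk:

# Hedged sketch of Eagle3SpecMetadata's default capture rule after this
# change; the target-model branch is elided here.
def default_layers_to_capture(num_layers: int, is_draft_model: bool):
    if is_draft_model or num_layers == 1:
        # A multi-layer draft still captures only its final layer's
        # hidden state, just like the old single-layer case.
        return (num_layers - 1, )
    raise NotImplementedError("target-model capture indices elided here")


assert default_layers_to_capture(2, is_draft_model=True) == (1, )
assert default_layers_to_capture(1, is_draft_model=False) == (0, )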

tensorrt_llm/_torch/speculative/utils.py

Lines changed: 2 additions & 1 deletion
@@ -151,7 +151,8 @@ def get_num_spec_layers(spec_config):
     if spec_config.spec_dec_mode.is_mtp():
         return spec_config.num_nextn_predict_layers
     if spec_config.spec_dec_mode.is_eagle3_one_model():
-        return 1
+        num_eagle_layers = spec_config.num_eagle_layers
+        return num_eagle_layers if num_eagle_layers is not None else 1
     return 0

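The fallback keeps old configs working: when num_eagle_layers is unset, the one-model Eagle3 path still allocates a single draft layer. A small illustration with a simplified signature (the real helper reads the field off the whole spec_config):

# Simplified signature for illustration; the real function takes
# spec_config and reads spec_config.num_eagle_layers.
def num_eagle3_spec_layers(num_eagle_layers=None) -> int:
    return num_eagle_layers if num_eagle_layers is not None else 1


assert num_eagle3_spec_layers() == 1   # legacy configs: one draft layer
assert num_eagle3_spec_layers(2) == 2  # MultiLayer Eagle opts in explicitly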

tests/unittest/_torch/speculative/test_eagle3.py

Lines changed: 101 additions & 0 deletions
@@ -226,5 +226,106 @@ def test_deepseek_eagle3():
     pass


+@pytest.mark.parametrize("use_one_model", [True, False])
+def test_multi_eagle3(use_one_model: bool):
+    use_cuda_graph = True
+    attn_backend = "TRTLLM"
+    disable_overlap_scheduler = False
+    enable_block_reuse = False
+    enable_chunked_prefill = False
+
+    # Eagle3 one-model works with the overlap scheduler and block reuse.
+    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
+    if total_mem_gb < 150:
+        pytest.skip("Not enough memory to load target + draft model")
+
+    models_path = llm_models_root()
+    eagle_config = {
+        'architectures': ['LlamaForCausalLMEagle3'],
+        'attention_bias': False,
+        'attention_dropout': 0.0,
+        'bos_token_id': 128000,
+        'eos_token_id': [128001, 128008, 128009],
+        'eagle_config': {
+            'use_aux_hidden_state': False,
+            'use_input_layernorm_in_first_layer': True,
+            'use_last_layernorm': True,
+            'use_mtp_layernorm': False
+        },
+        'head_dim': 128,
+        'hidden_act': 'silu',
+        'hidden_size': 4096,
+        'initializer_range': 0.02,
+        'intermediate_size': 16384,
+        'max_position_embeddings': 131072,
+        'mlp_bias': False,
+        'model_type': 'llama',
+        'num_attention_heads': 32,
+        'num_eagle_features': 1,
+        'num_hidden_layers': 2,
+        'num_key_value_heads': 8,
+        'pretraining_tp': 1,
+        'rms_norm_eps': 1e-05,
+        'rope_scaling': {
+            'factor': 8.0,
+            'high_freq_factor': 4.0,
+            'low_freq_factor': 1.0,
+            'original_max_position_embeddings': 8192,
+            'rope_type': 'llama3'
+        },
+        'rope_theta': 500000.0,
+        'tie_word_embeddings': False,
+        'torch_dtype': 'bfloat16',
+        'transformers_version': '4.52.4',
+        'use_cache': True,
+        'vocab_size': 128256,
+        'draft_vocab_size': 128256
+    }
+    with tempfile.TemporaryDirectory() as temp_dir:
+        eagle_model_dir = Path(temp_dir)
+        config_path = eagle_model_dir / "config.json"
+        with config_path.open("w") as f:
+            json.dump(eagle_config, f, indent=2)
+        target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
+
+        # bs > 1 gives non-deterministic results when doing IFB; there is a
+        # slight chance that ref and spec do not match 100%.
+        max_batch_size = 16
+        max_draft_len = 3
+        kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
+                                        free_gpu_memory_fraction=0.5)
+        cuda_graph_config = CudaGraphConfig(
+            batch_sizes=[1]) if use_cuda_graph else None
+
+        llm_common_config = dict(
+            model=target_model_dir,
+            attn_backend=attn_backend,
+            disable_overlap_scheduler=disable_overlap_scheduler,
+            cuda_graph_config=cuda_graph_config,
+            max_batch_size=max_batch_size,
+            kv_cache_config=kv_cache_config,
+            enable_chunked_prefill=enable_chunked_prefill,
+            load_format="dummy",
+        )
+
+        spec_config = EagleDecodingConfig(
+            max_draft_len=max_draft_len,
+            speculative_model_dir=eagle_model_dir,
+            # Llama 3 does not support one-model Eagle.
+            eagle3_one_model=use_one_model,
+            num_eagle_layers=2,
+            load_format="dummy")
+
+        llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
+
+        tok_ids = llm_spec.tokenizer.encode("The future of AI is")
+
+        sampling_params = SamplingParams(max_tokens=32, temperature=0)
+        for output in llm_spec.generate_async(tok_ids,
+                                              sampling_params,
+                                              streaming=True):
+            pass
+
+
 if __name__ == "__main__":
     unittest.main()
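Beyond the test, enabling the feature amounts to setting num_eagle_layers on EagleDecodingConfig to match the draft checkpoint's num_hidden_layers. A hedged usage sketch follows; the model paths are placeholders and the imports are assumed to match those used by test_eagle3.py.

# Hedged usage sketch; model paths are placeholders and the import layout
# is assumed to match what test_eagle3.py uses.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import EagleDecodingConfig

spec_config = EagleDecodingConfig(
    max_draft_len=3,
    speculative_model_dir="/path/to/eagle3-draft",  # placeholder
    eagle3_one_model=True,
    num_eagle_layers=2,  # the new knob: stack two Eagle3 decoder layers
)
llm = LLM(model="/path/to/Llama-3.1-8B-Instruct",  # placeholder
          speculative_config=spec_config)
out = llm.generate("The future of AI is",
                   SamplingParams(max_tokens=32, temperature=0))
print(out.outputs[0].text)

To exercise just the new test: pytest tests/unittest/_torch/speculative/test_eagle3.py -k multi_eagle3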
