
Commit e693c30

fix last bug and add tests
Signed-off-by: Izzy Putterman <[email protected]>
1 parent b693fdd commit e693c30

2 files changed: 106 additions, 0 deletions


tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 3 additions & 0 deletions
@@ -219,8 +219,11 @@ def forward(
         # we expect that to happen outside the model definition. This helps us
         # avoid data-dependent control flow and gives us better CUDA graph
         # coverage.
+        residual = None
         if self.num_layers > 1:
             for layer in self.midlayer:
+                if residual is not None:
+                    hidden_states = hidden_states + residual
                 hidden_states, residual = layer(position_ids=position_ids,
                                                 embeds=inputs_embeds,
                                                 hidden_states=hidden_states,
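
The change above threads the residual produced by one midlayer into the input of the next instead of silently dropping it. Below is a minimal, self-contained sketch of that pattern, not the TensorRT-LLM implementation: ToyLayer and ToyMidlayerStack are hypothetical stand-ins, and the only assumption carried over from the diff is that each layer returns a (hidden_states, residual) pair.

import torch
from torch import nn


class ToyLayer(nn.Module):
    """Hypothetical layer that mirrors the (hidden_states, residual) contract."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor):
        # Return the transformed states plus the residual to be folded in later.
        return self.linear(hidden_states), hidden_states


class ToyMidlayerStack(nn.Module):
    """Sketch of the fixed loop: the residual from layer i feeds layer i + 1."""

    def __init__(self, hidden_size: int, num_layers: int):
        super().__init__()
        self.midlayer = nn.ModuleList(
            ToyLayer(hidden_size) for _ in range(num_layers))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = None  # same initialization the commit adds
        for layer in self.midlayer:
            if residual is not None:
                # Fold the previous layer's residual back in before the call,
                # the step the fix inserts inside the loop.
                hidden_states = hidden_states + residual
            hidden_states, residual = layer(hidden_states)
        # The final residual is still pending; combine it here for the sketch.
        return hidden_states if residual is None else hidden_states + residual


if __name__ == "__main__":
    stack = ToyMidlayerStack(hidden_size=16, num_layers=2)
    out = stack(torch.randn(1, 4, 16))
    print(out.shape)  # torch.Size([1, 4, 16])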

tests/unittest/_torch/speculative/test_eagle3.py

Lines changed: 103 additions & 0 deletions
@@ -225,5 +225,108 @@ def test_deepseek_eagle3():
     pass


+@pytest.mark.parametrize("use_one_model", [
+    True,
+    False,
+])
+def test_multi_eagle3(use_one_model: bool):
+    use_cuda_graph = True
+    attn_backend = "TRTLLM"
+    disable_overlap_scheduler = False
+    enable_block_reuse = False
+    enable_chunked_prefill = False
+
+    # Eagle3 one model works with overlap scheduler and block reuse.
+    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
+    if total_mem_gb < 150:
+        pytest.skip("Not enough memory to load target + draft model")
+
+    models_path = llm_models_root()
+    eagle_config = {
+        'architectures': ['LlamaForCausalLMEagle3'],
+        'attention_bias': False,
+        'attention_dropout': 0.0,
+        'bos_token_id': 128000,
+        'eos_token_id': [128001, 128008, 128009],
+        'eagle_config': {
+            'use_aux_hidden_state': False,
+            'use_input_layernorm_in_first_layer': True,
+            'use_last_layernorm': True,
+            'use_mtp_layernorm': False
+        },
+        'head_dim': 128,
+        'hidden_act': 'silu',
+        'hidden_size': 4096,
+        'initializer_range': 0.02,
+        'intermediate_size': 16384,
+        'max_position_embeddings': 131072,
+        'mlp_bias': False,
+        'model_type': 'llama',
+        'num_attention_heads': 32,
+        'num_eagle_features': 1,
+        'num_hidden_layers': 2,
+        'num_key_value_heads': 8,
+        'pretraining_tp': 1,
+        'rms_norm_eps': 1e-05,
+        'rope_scaling': {
+            'factor': 8.0,
+            'high_freq_factor': 4.0,
+            'low_freq_factor': 1.0,
+            'original_max_position_embeddings': 8192,
+            'rope_type': 'llama3'
+        },
+        'rope_theta': 500000.0,
+        'tie_word_embeddings': False,
+        'torch_dtype': 'bfloat16',
+        'transformers_version': '4.52.4',
+        'use_cache': True,
+        'vocab_size': 128256,
+        'draft_vocab_size': 128256
+    }
+    with tempfile.TemporaryDirectory() as temp_dir:
+        eagle_model_dir = Path(temp_dir)
+        config_path = eagle_model_dir / "config.json"
+        with config_path.open("w") as f:
+            json.dump(eagle_config, f, indent=2)
+        target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
+
+        # bs > 1 gives non-deterministic results when doing IFB. There is a
+        # slight chance that ref and spec outputs do not match 100%.
+        max_batch_size = 16
+        max_draft_len = 3
+        kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
+                                        free_gpu_memory_fraction=0.5)
+        cuda_graph_config = CudaGraphConfig(
+            batch_sizes=[1]) if use_cuda_graph else None
+
+        llm_common_config = dict(
+            model=target_model_dir,
+            attn_backend=attn_backend,
+            disable_overlap_scheduler=disable_overlap_scheduler,
+            cuda_graph_config=cuda_graph_config,
+            max_batch_size=max_batch_size,
+            kv_cache_config=kv_cache_config,
+            enable_chunked_prefill=enable_chunked_prefill,
+        )
+
+        spec_config = EagleDecodingConfig(
+            max_draft_len=max_draft_len,
+            speculative_model_dir=eagle_model_dir,
+            # Llama 3 does not support one model eagle.
+            eagle3_one_model=use_one_model,
+            num_eagle_layers=2,
+            load_format="dummy")
+
+        llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
+
+        tok_ids = llm_spec.tokenizer.encode("The future of AI is")
+
+        sampling_params = SamplingParams(max_tokens=32, temperature=0)
+        for output in llm_spec.generate_async(tok_ids,
+                                              sampling_params,
+                                              streaming=True):
+            pass
+
+
 if __name__ == "__main__":
     unittest.main()
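
The new test can be selected on its own with pytest; the snippet below is only a usage sketch, assuming pytest is installed, the model root consulted by llm_models_root() is configured, and a GPU with at least 150 GB of memory is available (otherwise the test skips itself).

import pytest

# Run only test_multi_eagle3 (both use_one_model parametrizations), verbosely.
raise SystemExit(
    pytest.main([
        "tests/unittest/_torch/speculative/test_eagle3.py",
        "-k", "test_multi_eagle3",
        "-v",
    ]))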
