
Commit 9096fed

byshiue and lancelly authored and committed
[doc][ci][Qwen3][nvbugs 5374145] Add Qwen3 235B eagle3 CI (NVIDIA#6477)
Signed-off-by: bhsueh <[email protected]>
Signed-off-by: Lanyu Liao <[email protected]>
1 parent 82f574b commit 9096fed

File tree (5 files changed: +65 −7 lines)

- examples/models/core/qwen/README.md
- tests/integration/defs/accuracy/references/gsm8k.yaml
- tests/integration/defs/accuracy/references/mmlu.yaml
- tests/integration/defs/accuracy/test_llm_api_pytorch.py
- tests/integration/test_lists/test-db/l0_gb200_multi_nodes.yml

examples/models/core/qwen/README.md

Lines changed: 33 additions & 0 deletions
```diff
@@ -26,6 +26,7 @@ This document shows how to build and run a [Qwen](https://huggingface.co/Qwen) m
 - [Serving](#serving)
   - [trtllm-serve](#trtllm-serve)
   - [Disaggregated Serving](#disaggregated-serving)
+    - [Eagle3](#eagle3)
   - [Dynamo](#dynamo)
 - [Notes and Troubleshooting](#notes-and-troubleshooting)
 - [Credits](#credits)
@@ -888,6 +889,38 @@ Note that the optimal disaggregated serving configuration (i.e. tp/pp/ep mapping
 on the request parameters, the number of concurrent requests and the GPU type. It is recommended to experiment to identify optimal
 settings for your specific use case.
 
+#### Eagle3
+
+Qwen3 now supports Eagle3 speculative decoding. To enable Eagle3 on Qwen3, set the following arguments when running `trtllm-bench` or `trtllm-serve`:
+
+- `speculative_config.decoding_type: Eagle`
+  Set the decoding type to "Eagle" to enable Eagle3 speculative decoding.
+- `speculative_config.max_draft_len: 3`
+  Set the maximum number of draft tokens generated per step (adjust this value as needed).
+- `speculative_config.speculative_model_dir: <EAGLE3_DRAFT_MODEL_PATH>`
+  Specify the path to the Eagle3 draft model (ensure the corresponding draft model weights are prepared).
+
+Currently, there are some limitations when enabling Eagle3:
+
+1. `attention_dp` is not supported. Disable it or leave the related flag unset (it is disabled by default).
+2. To use `enable_block_reuse`, the KV cache type of the target model and the draft model must be the same. Since the draft model only supports fp16/bf16, you need to disable `enable_block_reuse` when using an fp8 KV cache.
+
+Example `extra-llm-api-config.yml` snippet for Eagle3:
+
+```bash
+echo "
+enable_attention_dp: false
+speculative_config:
+  decoding_type: Eagle
+  max_draft_len: 3
+  speculative_model_dir: <EAGLE3_DRAFT_MODEL_PATH>
+kv_cache_config:
+  enable_block_reuse: false
+" >> ${path_config}
+```
+
+For further details, please refer to [speculative-decoding.md](../../../../docs/source/advanced/speculative-decoding.md).
+
 ### Dynamo
 
 NVIDIA Dynamo is a high-throughput low-latency inference framework designed for serving generative AI and reasoning models in multi-node distributed environments.
```
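For orientation, a minimal sketch of how the YAML written above might be passed to `trtllm-serve` (an editor's addition, not part of this commit): `<QWEN3_MODEL_PATH>` is a placeholder, and the `--tp_size`/`--ep_size`/`--extra_llm_api_options`/`--port` flag names are assumptions to be checked against the trtllm-serve CLI of your TensorRT-LLM version.

```bash
# Hedged sketch (not from this commit): serve Qwen3 with the Eagle3 overrides
# appended to ${path_config} by the README snippet above.
# <QWEN3_MODEL_PATH> is a placeholder; the flag names are assumed, not verified.
path_config=extra-llm-api-config.yml
trtllm-serve <QWEN3_MODEL_PATH> \
    --tp_size 8 \
    --ep_size 8 \
    --extra_llm_api_options ${path_config} \
    --port 8000
```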

tests/integration/defs/accuracy/references/gsm8k.yaml

Lines changed: 4 additions & 0 deletions
```diff
@@ -86,6 +86,10 @@ Qwen3/Qwen3-235B-A22B:
   - quant_algo: NVFP4
     kv_cache_quant_algo: FP8
     accuracy: 85.78
+  - spec_dec_algo: Eagle
+    quant_algo: NVFP4
+    kv_cache_quant_algo: FP8
+    accuracy: 85.78
 nvidia/Llama-3_3-Nemotron-Super-49B-v1:
   - accuracy: 92.57
   - quant_algo: FP8
```

tests/integration/defs/accuracy/references/mmlu.yaml

Lines changed: 4 additions & 0 deletions
```diff
@@ -170,6 +170,10 @@ Qwen3/Qwen3-235B-A22B:
   - quant_algo: NVFP4
     kv_cache_quant_algo: FP8
     accuracy: 86
+  - spec_dec_algo: Eagle
+    quant_algo: NVFP4
+    kv_cache_quant_algo: FP8
+    accuracy: 86
 nvidia/Llama-3_3-Nemotron-Super-49B-v1:
   - accuracy: 79.43
   - quant_algo: FP8
```

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 23 additions & 7 deletions
```diff
@@ -1971,28 +1971,44 @@ def test_fp8(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
     @skip_pre_blackwell
     @pytest.mark.skip_less_mpi_world_size(8)
     @pytest.mark.parametrize(
-        "tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler,moe_backend",
-        [(8, 1, 8, True, True, True, "CUTLASS"),
-         (8, 1, 8, True, True, True, "TRTLLM")],
-        ids=["latency_moe_cutlass", "latency_moe_trtllm"],
+        "tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler,moe_backend,eagle3",
+        [
+            (8, 1, 8, True, True, True, "CUTLASS", False),
+            (8, 1, 8, True, True, True, "TRTLLM", False),
+            (8, 1, 8, False, False, False, "TRTLLM", True),
+        ],
+        ids=[
+            "latency_moe_cutlass", "latency_moe_trtllm",
+            "latency_moe_trtllm_eagle3"
+        ],
     )
     def test_nvfp4(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
-                   overlap_scheduler, moe_backend):
+                   overlap_scheduler, moe_backend, eagle3):
 
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
             moe_config=MoeConfig(backend=moe_backend))
 
-        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4,
+                                        enable_block_reuse=not eagle3)
+        spec_config = None
+        if eagle3:
+            spec_config = EagleDecodingConfig(
+                max_draft_len=2,
+                speculative_model_dir=
+                f"{llm_models_root()}/Qwen3/qwen3-235B-eagle3/",
+                eagle3_one_model=True)
         with LLM(
                 f"{llm_models_root()}/Qwen3/saved_models_Qwen3-235B-A22B_nvfp4_hf",
                 tensor_parallel_size=tp_size,
                 pipeline_parallel_size=pp_size,
                 moe_expert_parallel_size=ep_size,
                 **pytorch_config,
                 enable_attention_dp=attention_dp,
-                kv_cache_config=kv_cache_config) as llm:
+                kv_cache_config=kv_cache_config,
+                speculative_config=spec_config) as llm:
+
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm)
             task = GSM8K(self.MODEL_NAME)
```
tests/integration/test_lists/test-db/l0_gb200_multi_nodes.yml

Lines changed: 1 addition & 0 deletions
```diff
@@ -18,3 +18,4 @@ l0_gb200_multi_nodes:
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_trtllmgen] TIMEOUT (180)
 - accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_cutlass] TIMEOUT (180)
 - accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm] TIMEOUT (180)
+- accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_eagle3] TIMEOUT (180)
```
