
Commit 28858c8

feat(eagle3): support qwen3 dense model (#5879)
Signed-off-by: xq25478 <[email protected]>
1 parent 22d4a8c

File tree: 4 files changed, +39 −32 lines
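In short: Qwen3ForCausalLM now derives from SpecDecOneEngineForCausalLM instead of DecoderModelForCausalLM, an optional spec_metadata argument is threaded through the decoder so Eagle3 can capture intermediate hidden states, and the change is covered by a new MMLU accuracy test for Qwen3-8B with an Eagle3 draft model, wired into the H100 test list.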

tensorrt_llm/_torch/models/modeling_qwen3.py

Lines changed: 12 additions & 32 deletions
@@ -16,8 +16,9 @@
 from ..modules.linear import TensorParallelMode
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
-from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
-                             register_auto_model)
+from ..speculative import SpecMetadata
+from .modeling_speculative import SpecDecOneEngineForCausalLM
+from .modeling_utils import DecoderModel, register_auto_model


 class Qwen3Attention(Attention):
@@ -148,6 +149,7 @@ def forward(
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if residual is None:
@@ -171,6 +173,10 @@ def forward(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)

+        if spec_metadata is not None:
+            spec_metadata.maybe_capture_hidden_states(self.layer_idx,
+                                                      hidden_states, residual)
+
         return hidden_states, residual


@@ -207,6 +213,7 @@ def forward(
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -227,48 +234,21 @@ def forward(
                 attn_metadata=attn_metadata,
                 residual=residual,
                 mrope_config=mrope_config,
+                spec_metadata=spec_metadata,
             )

         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states


 @register_auto_model("Qwen3ForCausalLM")
-class Qwen3ForCausalLM(DecoderModelForCausalLM[Qwen3Model, Qwen3Config]):
+class Qwen3ForCausalLM(SpecDecOneEngineForCausalLM[Qwen3Model, Qwen3Config]):

     def __init__(
         self,
         model_config: ModelConfig[Qwen3Config],
     ):
         super().__init__(
             Qwen3Model(model_config),
-            config=model_config,
-            hidden_size=model_config.pretrained_config.hidden_size,
-            vocab_size=model_config.pretrained_config.vocab_size,
-        )
-
-    # NOTE: Qwen2-VL needs special mrope_config so adding separate forward() function to accept 'mrope_config'.
-    def forward(
-        self,
-        attn_metadata: AttentionMetadata,
-        input_ids: torch.IntTensor = None,
-        position_ids: Optional[torch.IntTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        return_context_logits: bool = False,
-        mrope_config: Optional[dict] = None,
-        **kwargs,
-    ) -> torch.Tensor:
-        output = self.model(
-            input_ids=input_ids,
-            attn_metadata=attn_metadata,
-            position_ids=position_ids,
-            inputs_embeds=inputs_embeds,
-            mrope_config=mrope_config,
-        )
-
-        return self.logits_processor.forward(
-            output,
-            self.lm_head,
-            attn_metadata,
-            return_context_logits,
+            model_config,
         )
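For orientation: Eagle3 drafts from hidden states captured at a few target-model layers, which is what the new spec_metadata plumbing above enables. Below is a minimal, self-contained sketch of the capture pattern; ToySpecMetadata is a hypothetical stand-in for the real SpecMetadata imported from ..speculative, and the layer indices and shapes are illustrative only.

from typing import Dict, Optional, Set

import torch


class ToySpecMetadata:
    """Hypothetical stand-in for TensorRT-LLM's SpecMetadata: records the
    hidden states of a chosen set of decoder layers during the target
    model's forward pass, which is what an Eagle3 draft head consumes."""

    def __init__(self, layers_to_capture: Set[int]):
        self.layers_to_capture = layers_to_capture
        self.captured: Dict[int, torch.Tensor] = {}

    def maybe_capture_hidden_states(
            self,
            layer_idx: int,
            hidden_states: torch.Tensor,
            residual: Optional[torch.Tensor] = None) -> None:
        # No-op for layers that are not being captured, so every decoder
        # layer can make this call unconditionally, as the diff does.
        if layer_idx in self.layers_to_capture:
            states = hidden_states if residual is None else hidden_states + residual
            self.captured[layer_idx] = states.detach()


# Each decoder layer ends with the same guarded call the diff adds:
spec_metadata = ToySpecMetadata(layers_to_capture={0, 17, 34})
for layer_idx in range(36):                # illustrative layer count
    hidden_states = torch.randn(4, 4096)   # stand-in for the layer output
    residual = torch.randn(4, 4096)        # stand-in for the residual stream
    if spec_metadata is not None:
        spec_metadata.maybe_capture_hidden_states(layer_idx, hidden_states,
                                                  residual)
print(sorted(spec_metadata.captured))      # -> [0, 17, 34]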

tests/integration/defs/accuracy/references/mmlu.yaml

Lines changed: 2 additions & 0 deletions
@@ -150,6 +150,8 @@ Qwen3/Qwen3-8B:
   - quant_algo: FP8_BLOCK_SCALES
     accuracy: 76.12
   - accuracy: 76.12
+  - spec_dec_algo: Eagle
+    accuracy: 76.12
 Qwen3/Qwen3-30B-A3B:
   - quant_algo: FP8_BLOCK_SCALES
     accuracy: 79.53
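A note on the reference value: Eagle-style speculation lets the target model verify every drafted token, so a correct implementation leaves outputs unchanged, which is why the new spec_dec_algo: Eagle entry carries the same 76.12 as the baseline.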

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 24 additions & 0 deletions
@@ -1658,6 +1658,30 @@ def test_bf16(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm)

+    def test_eagle3(self):
+        pytorch_config = dict(
+            disable_overlap_scheduler=True,
+            cuda_graph_config=CudaGraphConfig(batch_sizes=[1]),
+        )
+        kv_cache_config = KvCacheConfig(enable_block_reuse=False)
+
+        eagle_model_dir = f"{llm_models_root()}/Qwen3/qwen3_8b_eagle3"
+        target_model_dir = f"{llm_models_root()}/Qwen3/Qwen3-8B"
+
+        draft_len = 4
+        spec_config = EagleDecodingConfig(max_draft_len=draft_len,
+                                          speculative_model_dir=eagle_model_dir)
+
+        llm = LLM(model=target_model_dir,
+                  **pytorch_config,
+                  kv_cache_config=kv_cache_config,
+                  speculative_config=spec_config,
+                  build_config=None)
+
+        with llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+

 class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "Qwen3/Qwen3-30B-A3B"
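Outside the test harness, the same configuration drives ordinary generation. A minimal, unverified sketch: the checkpoint paths are placeholders, and the llmapi import names mirror what the test above relies on, so adjust both to your local setup and TensorRT-LLM version.

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
                                 KvCacheConfig)

spec_config = EagleDecodingConfig(
    max_draft_len=4,  # same draft length the test uses
    speculative_model_dir="/models/Qwen3/qwen3_8b_eagle3",  # placeholder path
)

llm = LLM(
    model="/models/Qwen3/Qwen3-8B",  # placeholder path
    disable_overlap_scheduler=True,  # the test disables overlap scheduling
    cuda_graph_config=CudaGraphConfig(batch_sizes=[1]),
    kv_cache_config=KvCacheConfig(enable_block_reuse=False),
    speculative_config=spec_config,
)

with llm:
    result = llm.generate("The capital of France is")
    print(result.outputs[0].text)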

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ l0_h100:
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=fp8-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
 - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
 - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency]
+- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding[mtp_nextn=0]
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding[mtp_nextn=2]
 - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
