
Commit 99adbc4

feat(eagle3): support qwen3 dense model
Signed-off-by: xq25478 <[email protected]>
1 parent 66f299a commit 99adbc4

File tree

2 files changed (+39, -6 lines)


tensorrt_llm/_torch/models/modeling_qwen3.py

Lines changed: 14 additions & 6 deletions
@@ -16,8 +16,9 @@
 from ..modules.linear import TensorParallelMode
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
-from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
-                             register_auto_model)
+from ..speculative import SpecMetadata
+from .modeling_speculative import SpecDecOneEngineForCausalLM
+from .modeling_utils import DecoderModel, register_auto_model
 
 
 class Qwen3Attention(Attention):
@@ -148,6 +149,7 @@ def forward(
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if residual is None:
@@ -171,6 +173,10 @@ def forward(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
 
+        if spec_metadata is not None:
+            spec_metadata.maybe_capture_hidden_states(self.layer_idx,
+                                                      hidden_states, residual)
+
         return hidden_states, residual
 
 
@@ -207,6 +213,7 @@ def forward(
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -227,24 +234,23 @@ def forward(
                 attn_metadata=attn_metadata,
                 residual=residual,
                 mrope_config=mrope_config,
+                spec_metadata=spec_metadata,
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
 
 @register_auto_model("Qwen3ForCausalLM")
-class Qwen3ForCausalLM(DecoderModelForCausalLM[Qwen3Model, Qwen3Config]):
+class Qwen3ForCausalLM(SpecDecOneEngineForCausalLM[Qwen3Model, Qwen3Config]):
 
     def __init__(
         self,
         model_config: ModelConfig[Qwen3Config],
     ):
         super().__init__(
             Qwen3Model(model_config),
-            config=model_config,
-            hidden_size=model_config.pretrained_config.hidden_size,
-            vocab_size=model_config.pretrained_config.vocab_size,
+            model_config,
         )
 
     # NOTE: Qwen2-VL needs special mrope_config so adding separate forward() function to accept 'mrope_config'.
@@ -256,6 +262,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         return_context_logits: bool = False,
         mrope_config: Optional[dict] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         output = self.model(
@@ -264,6 +271,7 @@
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
             mrope_config=mrope_config,
+            spec_metadata=spec_metadata,
         )
 
         return self.logits_processor.forward(
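The decoder-layer change above is the heart of the EAGLE3 wiring: every layer unconditionally offers its post-MLP hidden states and residual to the speculative-decoding metadata, and the metadata object decides which layers actually get recorded for the draft model. The real SpecMetadata lives in tensorrt_llm/_torch/speculative; the sketch below is only an illustration of the contract the layer relies on, and every name in it other than maybe_capture_hidden_states is an assumption, not the library's API:

# Illustrative sketch only -- not the actual SpecMetadata implementation.
# `layers_to_capture` and `captured` are hypothetical names.
from dataclasses import dataclass, field
from typing import Dict, Set, Tuple

import torch


@dataclass
class SpecMetadataSketch:
    # Indices of the decoder layers whose outputs the EAGLE3 draft
    # head was trained to consume.
    layers_to_capture: Set[int] = field(default_factory=set)
    # layer_idx -> (hidden_states, residual) recorded during the
    # target model's forward pass.
    captured: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = field(
        default_factory=dict)

    def maybe_capture_hidden_states(self, layer_idx: int,
                                    hidden_states: torch.Tensor,
                                    residual: torch.Tensor) -> None:
        # A no-op for layers the draft model does not read, which is
        # why the decoder layer can call this unconditionally whenever
        # spec_metadata is present.
        if layer_idx in self.layers_to_capture:
            self.captured[layer_idx] = (hidden_states, residual)

Rebasing Qwen3ForCausalLM onto SpecDecOneEngineForCausalLM presumably lets the same single-engine model serve both plain and speculative decoding, so the per-layer hook has to stay cheap when no capture is configured.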

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 25 additions & 0 deletions
@@ -1650,6 +1650,31 @@ def test_bf16(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
         task = MMLU(self.MODEL_NAME)
         task.evaluate(llm)
 
+    def test_eagle3(self):
+        pytorch_config = dict(
+            disable_overlap_scheduler=True,
+            cuda_graph_config=CudaGraphConfig(batch_sizes=[1]),
+        )
+        kv_cache_config = KvCacheConfig(enable_block_reuse=False)
+
+        eagle_model_dir = f"{llm_models_root()}/qwen3_8b_eagle3"
+        target_model_dir = f"{llm_models_root()}/Qwen3/Qwen3-8B"
+
+        draft_len = 4
+        spec_config = EagleDecodingConfig(max_draft_len=draft_len,
+                                          pytorch_weights_path=eagle_model_dir,
+                                          eagle3_one_model=False)
+
+        llm = LLM(model=target_model_dir,
+                  **pytorch_config,
+                  kv_cache_config=kv_cache_config,
+                  speculative_config=spec_config,
+                  build_config=None)
+
+        with llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+
 
 class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "Qwen3/Qwen3-30B-A3B"
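The new test doubles as a usage recipe for EAGLE3 two-model speculative decoding on Qwen3-8B. Stripped of the test harness, the same configuration looks roughly like this; the checkpoint paths are placeholders, the generate() call is added for illustration, and import locations (e.g. tensorrt_llm.llmapi) may differ between TensorRT-LLM versions:

# Minimal sketch derived from test_eagle3; paths are placeholders.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
                                 KvCacheConfig)

# Draft model proposes 4 tokens per step; two-model mode, as in the test.
spec_config = EagleDecodingConfig(
    max_draft_len=4,
    pytorch_weights_path="/path/to/qwen3_8b_eagle3",  # EAGLE3 draft weights
    eagle3_one_model=False,
)

llm = LLM(
    model="/path/to/Qwen3/Qwen3-8B",  # target (verifier) model
    disable_overlap_scheduler=True,
    cuda_graph_config=CudaGraphConfig(batch_sizes=[1]),
    kv_cache_config=KvCacheConfig(enable_block_reuse=False),
    speculative_config=spec_config,
)

with llm:
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)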
