Commit ced83ff

feat(eagle3): support qwen3 dense model
Signed-off-by: xq25478 <[email protected]>
1 parent: 66f299a

2 files changed: +46 additions, −6 deletions


tensorrt_llm/_torch/models/modeling_qwen3.py

Lines changed: 21 additions & 6 deletions
@@ -16,8 +16,9 @@
 from ..modules.linear import TensorParallelMode
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
-from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
-                             register_auto_model)
+from ..speculative import SpecMetadata
+from .modeling_speculative import SpecDecOneEngineForCausalLM
+from .modeling_utils import DecoderModel, EagerFusionConfig, register_auto_model
 
 
 class Qwen3Attention(Attention):
@@ -141,13 +142,21 @@ def __init__(
             eps=config.rms_norm_eps,
             dtype=config.torch_dtype)
 
+        self.fusion_config = EagerFusionConfig()
+        # self.fusion_config.PRE_MOE_FUSION = model_config.mapping.has_tp(
+        # )
+        # TODO: re-enable these fusions
+        self.fusion_config.PRE_MOE_FUSION = False
+        self.fusion_config.POST_MLP_FUSION = False
+
     def forward(
         self,
         position_ids: torch.IntTensor,
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if residual is None:
@@ -171,6 +180,10 @@ def forward(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
 
+        if spec_metadata is not None:
+            spec_metadata.maybe_capture_hidden_states(self.layer_idx,
+                                                      hidden_states, residual)
+
         return hidden_states, residual
 
 
@@ -207,6 +220,7 @@ def forward(
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -227,24 +241,23 @@ def forward(
                 attn_metadata=attn_metadata,
                 residual=residual,
                 mrope_config=mrope_config,
+                spec_metadata=spec_metadata,
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
 
 @register_auto_model("Qwen3ForCausalLM")
-class Qwen3ForCausalLM(DecoderModelForCausalLM[Qwen3Model, Qwen3Config]):
+class Qwen3ForCausalLM(SpecDecOneEngineForCausalLM[Qwen3Model, Qwen3Config]):
 
     def __init__(
         self,
         model_config: ModelConfig[Qwen3Config],
     ):
         super().__init__(
             Qwen3Model(model_config),
-            config=model_config,
-            hidden_size=model_config.pretrained_config.hidden_size,
-            vocab_size=model_config.pretrained_config.vocab_size,
+            model_config,
         )
 
     # NOTE: Qwen2-VL needs special mrope_config so adding separate forward() function to accept 'mrope_config'.
@@ -256,6 +269,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         return_context_logits: bool = False,
         mrope_config: Optional[dict] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         output = self.model(
@@ -264,6 +278,7 @@
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
             mrope_config=mrope_config,
+            spec_metadata=spec_metadata,
         )
 
         return self.logits_processor.forward(
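
The heart of this file's change is the spec_metadata plumbing: Qwen3ForCausalLM now derives from SpecDecOneEngineForCausalLM, and each decoder layer offers its post-MLP hidden states (plus the running residual) to the speculative-decoding machinery via maybe_capture_hidden_states. Below is a minimal, self-contained sketch of that capture contract; HiddenStateCaptureSketch is illustrative only, not the real SpecMetadata from tensorrt_llm._torch.speculative, and the captured layer indices are made up.

# Sketch of the contract Qwen3DecoderLayer.forward() now assumes: an optional
# object whose maybe_capture_hidden_states(layer_idx, hidden_states, residual)
# records activations from selected layers for the Eagle3 draft model.
from typing import Dict, Optional, Set, Tuple

import torch


class HiddenStateCaptureSketch:
    """Records (hidden_states, residual) from a chosen subset of layers."""

    def __init__(self, layers_to_capture: Set[int]):
        self.layers_to_capture = layers_to_capture
        # layer_idx -> (hidden_states, residual)
        self.captured: Dict[int, Tuple[torch.Tensor, Optional[torch.Tensor]]] = {}

    def maybe_capture_hidden_states(
            self, layer_idx: int, hidden_states: torch.Tensor,
            residual: Optional[torch.Tensor] = None) -> None:
        # Silently skip layers the draft model does not read from.
        if layer_idx in self.layers_to_capture:
            self.captured[layer_idx] = (hidden_states, residual)


# Mirrors the call site added to Qwen3DecoderLayer.forward():
spec_metadata = HiddenStateCaptureSketch(layers_to_capture={0, 15, 31})
hidden_states = torch.randn(2, 4096)
residual = torch.randn(2, 4096)
if spec_metadata is not None:
    spec_metadata.maybe_capture_hidden_states(15, hidden_states, residual)
assert list(spec_metadata.captured) == [15]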

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 25 additions & 0 deletions
@@ -1650,6 +1650,31 @@ def test_bf16(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
         task = MMLU(self.MODEL_NAME)
         task.evaluate(llm)
 
+    def test_eagle3(self):
+        pytorch_config = dict(
+            disable_overlap_scheduler=True,
+            cuda_graph_config=CudaGraphConfig(batch_sizes=[1]),
+        )
+        kv_cache_config = KvCacheConfig(enable_block_reuse=False)
+
+        eagle_model_dir = f"{llm_models_root()}/qwen3_8b_eagle3"
+        target_model_dir = f"{llm_models_root()}/Qwen3/Qwen3-8B"
+
+        draft_len = 4
+        spec_config = EagleDecodingConfig(max_draft_len=draft_len,
+                                          pytorch_weights_path=eagle_model_dir,
+                                          eagle3_one_model=False)
+
+        llm = LLM(model=target_model_dir,
+                  **pytorch_config,
+                  kv_cache_config=kv_cache_config,
+                  speculative_config=spec_config,
+                  build_config=None)
+
+        with llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+
 
 class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "Qwen3/Qwen3-30B-A3B"
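
The new test doubles as a usage recipe for Eagle3 speculative decoding on a dense Qwen3 target. The standalone sketch below rearranges it outside the test harness; the import paths and checkpoint locations are assumptions (CudaGraphConfig, EagleDecodingConfig, and KvCacheConfig are taken to be importable from tensorrt_llm.llmapi, and the two model directories are placeholders), so treat it as a template rather than a verified script.

# Eagle3 + Qwen3-8B via the TensorRT-LLM LLM API, adapted from test_eagle3 above.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
                                 KvCacheConfig)

target_model_dir = "/models/Qwen3/Qwen3-8B"   # placeholder: Qwen3-8B target
eagle_model_dir = "/models/qwen3_8b_eagle3"   # placeholder: Eagle3 draft head

spec_config = EagleDecodingConfig(
    max_draft_len=4,                       # draft length used by the test
    pytorch_weights_path=eagle_model_dir,  # Eagle3 draft-head weights
    eagle3_one_model=False)                # two-model (target + draft) mode

llm = LLM(model=target_model_dir,
          disable_overlap_scheduler=True,   # same settings as the test
          cuda_graph_config=CudaGraphConfig(batch_sizes=[1]),
          kv_cache_config=KvCacheConfig(enable_block_reuse=False),
          speculative_config=spec_config)

with llm:
    for output in llm.generate(["What is the capital of France?"]):
        print(output.outputs[0].text)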
