 from ..modules.linear import TensorParallelMode
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
-from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
-                             register_auto_model)
+from ..speculative import SpecMetadata
+from .modeling_speculative import SpecDecOneEngineForCausalLM
+from .modeling_utils import DecoderModel, register_auto_model
 
 
 class Qwen3Attention(Attention):
@@ -148,6 +149,7 @@ def forward(
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if residual is None:
@@ -171,6 +173,10 @@ def forward(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
 
+        if spec_metadata is not None:
+            spec_metadata.maybe_capture_hidden_states(self.layer_idx,
+                                                      hidden_states, residual)
+
         return hidden_states, residual
 
 
@@ -207,6 +213,7 @@ def forward(
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -227,48 +234,21 @@ def forward(
                 attn_metadata=attn_metadata,
                 residual=residual,
                 mrope_config=mrope_config,
+                spec_metadata=spec_metadata,
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
 
 @register_auto_model("Qwen3ForCausalLM")
-class Qwen3ForCausalLM(DecoderModelForCausalLM[Qwen3Model, Qwen3Config]):
+class Qwen3ForCausalLM(SpecDecOneEngineForCausalLM[Qwen3Model, Qwen3Config]):
 
     def __init__(
         self,
         model_config: ModelConfig[Qwen3Config],
     ):
         super().__init__(
             Qwen3Model(model_config),
-            config=model_config,
-            hidden_size=model_config.pretrained_config.hidden_size,
-            vocab_size=model_config.pretrained_config.vocab_size,
-        )
-
-    # NOTE: Qwen2-VL needs special mrope_config so adding separate forward() function to accept 'mrope_config'.
-    def forward(
-        self,
-        attn_metadata: AttentionMetadata,
-        input_ids: torch.IntTensor = None,
-        position_ids: Optional[torch.IntTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        return_context_logits: bool = False,
-        mrope_config: Optional[dict] = None,
-        **kwargs,
-    ) -> torch.Tensor:
-        output = self.model(
-            input_ids=input_ids,
-            attn_metadata=attn_metadata,
-            position_ids=position_ids,
-            inputs_embeds=inputs_embeds,
-            mrope_config=mrope_config,
-        )
-
-        return self.logits_processor.forward(
-            output,
-            self.lm_head,
-            attn_metadata,
-            return_context_logits,
+            model_config,
         )
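
For context, below is a minimal, self-contained sketch of the hidden-state capture pattern this diff threads through the model: each decoder layer offers its (hidden_states, residual) pair to a spec_metadata object, which records it only for the layers the speculative-decoding drafter needs. ToySpecMetadata and run_layers are hypothetical stand-ins used purely for illustration; they are not the real SpecMetadata or SpecDecOneEngineForCausalLM API.

# Minimal sketch of the capture pattern (hypothetical stand-in classes).
from typing import Dict, Optional, Set, Tuple

import torch


class ToySpecMetadata:
    """Records (hidden_states, residual) from selected decoder layers."""

    def __init__(self, layers_to_capture: Set[int]):
        self.layers_to_capture = layers_to_capture
        self.captured: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}

    def maybe_capture_hidden_states(self, layer_idx: int,
                                    hidden_states: torch.Tensor,
                                    residual: torch.Tensor) -> None:
        # Only the layers the speculative drafter asked for are stored.
        if layer_idx in self.layers_to_capture:
            self.captured[layer_idx] = (hidden_states, residual)


def run_layers(hidden_states: torch.Tensor,
               num_layers: int,
               spec_metadata: Optional[ToySpecMetadata] = None) -> torch.Tensor:
    # Stand-in for the model's layer loop: each "layer" updates the
    # activations and, like the decoder layer in this diff, hands them to
    # spec_metadata before moving on.
    residual = torch.zeros_like(hidden_states)
    for layer_idx in range(num_layers):
        hidden_states, residual = hidden_states + residual, hidden_states
        if spec_metadata is not None:
            spec_metadata.maybe_capture_hidden_states(layer_idx, hidden_states,
                                                      residual)
    return hidden_states


if __name__ == "__main__":
    meta = ToySpecMetadata(layers_to_capture={2})
    out = run_layers(torch.randn(2, 8), num_layers=4, spec_metadata=meta)
    print(sorted(meta.captured))  # -> [2]

When spec_metadata is None (the non-speculative path), the layers behave exactly as before, which is why the parameter defaults to None throughout the diff.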