@@ -16,8 +16,9 @@
 from ..modules.linear import TensorParallelMode
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
-from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
-                             register_auto_model)
+from ..speculative import SpecMetadata
+from .modeling_speculative import SpecDecOneEngineForCausalLM
+from .modeling_utils import DecoderModel, register_auto_model
 
 
 class Qwen3Attention(Attention):
@@ -148,6 +149,7 @@ def forward(
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if residual is None:
@@ -171,6 +173,10 @@ def forward(
                 hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
 
+        if spec_metadata is not None:
+            spec_metadata.maybe_capture_hidden_states(self.layer_idx,
+                                                      hidden_states, residual)
+
         return hidden_states, residual
 
 
@@ -207,6 +213,7 @@ def forward(
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -227,48 +234,21 @@ def forward(
             attn_metadata=attn_metadata,
             residual=residual,
             mrope_config=mrope_config,
+            spec_metadata=spec_metadata,
         )
 
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
 
 @register_auto_model("Qwen3ForCausalLM")
-class Qwen3ForCausalLM(DecoderModelForCausalLM[Qwen3Model, Qwen3Config]):
+class Qwen3ForCausalLM(SpecDecOneEngineForCausalLM[Qwen3Model, Qwen3Config]):
 
     def __init__(
         self,
         model_config: ModelConfig[Qwen3Config],
     ):
         super().__init__(
             Qwen3Model(model_config),
-            config=model_config,
-            hidden_size=model_config.pretrained_config.hidden_size,
-            vocab_size=model_config.pretrained_config.vocab_size,
-        )
-
-    # NOTE: Qwen2-VL needs special mrope_config so adding separate forward() function to accept 'mrope_config'.
-    def forward(
-        self,
-        attn_metadata: AttentionMetadata,
-        input_ids: torch.IntTensor = None,
-        position_ids: Optional[torch.IntTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        return_context_logits: bool = False,
-        mrope_config: Optional[dict] = None,
-        **kwargs,
-    ) -> torch.Tensor:
-        output = self.model(
-            input_ids=input_ids,
-            attn_metadata=attn_metadata,
-            position_ids=position_ids,
-            inputs_embeds=inputs_embeds,
-            mrope_config=mrope_config,
-        )
-
-        return self.logits_processor.forward(
-            output,
-            self.lm_head,
-            attn_metadata,
-            return_context_logits,
+            model_config,
         )
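
For readers unfamiliar with the speculative-decoding plumbing this diff wires in: `SpecDecOneEngineForCausalLM` runs the target and draft models in a single engine, and `spec_metadata.maybe_capture_hidden_states(layer_idx, hidden_states, residual)` lets a drafter (e.g. an EAGLE3-style one that reads intermediate activations) record hidden states from selected decoder layers while the target model runs. The sketch below is a minimal, hypothetical stand-in for `SpecMetadata`: the class name `ToySpecMetadata`, its constructor, and the `layers_to_capture` field are assumptions for illustration only; the only thing taken from the diff is the `maybe_capture_hidden_states` call signature.

    # Hypothetical sketch, NOT TensorRT-LLM's actual SpecMetadata implementation.
    # It illustrates the capture-hook pattern the diff threads through the model:
    # the drafter registers the layer indices it needs, and every decoder layer
    # offers its output for capture unconditionally.
    from typing import Dict, List, Optional, Tuple

    import torch


    class ToySpecMetadata:

        def __init__(self, layers_to_capture: List[int]):
            # Layer indices the speculative drafter wants hidden states from.
            self.layers_to_capture = set(layers_to_capture)
            # layer_idx -> (hidden_states, residual) recorded during the pass.
            self.captured: Dict[int, Tuple[torch.Tensor,
                                           Optional[torch.Tensor]]] = {}

        def maybe_capture_hidden_states(
                self,
                layer_idx: int,
                hidden_states: torch.Tensor,
                residual: Optional[torch.Tensor] = None) -> None:
            # Cheap no-op for layers the drafter does not care about, so the
            # hook can be called from every layer without a cost concern.
            if layer_idx in self.layers_to_capture:
                self.captured[layer_idx] = (hidden_states, residual)


    if __name__ == "__main__":
        spec_metadata = ToySpecMetadata(layers_to_capture=[35])
        h = torch.randn(4, 2048)
        r = torch.randn(4, 2048)
        spec_metadata.maybe_capture_hidden_states(35, h, r)  # captured
        spec_metadata.maybe_capture_hidden_states(10, h, r)  # ignored
        assert set(spec_metadata.captured) == {35}

This also explains why the guard in the decoder layer is simply `if spec_metadata is not None`: when no speculative decoding is configured, the forward path is unchanged, and when it is, the per-layer hook decides on its own whether a given layer's output is needed.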