from ..modules.linear import TensorParallelMode
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
-from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
-                             register_auto_model)
+from ..speculative import SpecMetadata
+from .modeling_speculative import SpecDecOneEngineForCausalLM
+from .modeling_utils import DecoderModel, EagerFusionConfig, register_auto_model


class Qwen3Attention(Attention):
@@ -141,13 +142,21 @@ def __init__(
                            eps=config.rms_norm_eps,
                            dtype=config.torch_dtype)

+        self.fusion_config = EagerFusionConfig()
+        # self.fusion_config.PRE_MOE_FUSION = model_config.mapping.has_tp(
+        # )
+        # TODO: re-enable these fusions
+        self.fusion_config.PRE_MOE_FUSION = False
+        self.fusion_config.POST_MLP_FUSION = False
+
    def forward(
        self,
        position_ids: torch.IntTensor,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
        **kwargs,
    ) -> torch.Tensor:
        if residual is None:
@@ -171,6 +180,10 @@ def forward(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

+        if spec_metadata is not None:
+            spec_metadata.maybe_capture_hidden_states(self.layer_idx,
+                                                      hidden_states, residual)
+
        return hidden_states, residual

@@ -207,6 +220,7 @@ def forward(
        position_ids: Optional[torch.IntTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
        **kwargs,
    ) -> torch.Tensor:
        if (input_ids is None) ^ (inputs_embeds is not None):
@@ -227,24 +241,23 @@ def forward(
            attn_metadata=attn_metadata,
            residual=residual,
            mrope_config=mrope_config,
+            spec_metadata=spec_metadata,
        )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


@register_auto_model("Qwen3ForCausalLM")
-class Qwen3ForCausalLM(DecoderModelForCausalLM[Qwen3Model, Qwen3Config]):
+class Qwen3ForCausalLM(SpecDecOneEngineForCausalLM[Qwen3Model, Qwen3Config]):

    def __init__(
        self,
        model_config: ModelConfig[Qwen3Config],
    ):
        super().__init__(
            Qwen3Model(model_config),
-            config=model_config,
-            hidden_size=model_config.pretrained_config.hidden_size,
-            vocab_size=model_config.pretrained_config.vocab_size,
+            model_config,
        )

    # NOTE: Qwen2-VL needs special mrope_config so adding separate forward() function to accept 'mrope_config'.
@@ -256,6 +269,7 @@ def forward(
        inputs_embeds: Optional[torch.FloatTensor] = None,
        return_context_logits: bool = False,
        mrope_config: Optional[dict] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
        **kwargs,
    ) -> torch.Tensor:
        output = self.model(
@@ -264,6 +278,7 @@ def forward(
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            mrope_config=mrope_config,
+            spec_metadata=spec_metadata,
        )

        return self.logits_processor.forward(
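For illustration only, not part of the diff: a minimal, self-contained sketch of the capture pattern that the new `spec_metadata` plumbing enables. `_SpecMetadataSketch` is a hypothetical stand-in for the real `SpecMetadata` from `..speculative`; only the `maybe_capture_hidden_states(layer_idx, hidden_states, residual)` call shape mirrors what the diff adds to each decoder layer.

```python
from typing import Optional

import torch


class _SpecMetadataSketch:
    """Hypothetical stand-in for SpecMetadata, for illustration only.

    Records the hidden states of selected decoder layers so a
    speculative-decoding drafter can reuse them later.
    """

    def __init__(self, layers_to_capture: set[int]):
        self.layers_to_capture = layers_to_capture
        self.captured: dict[int, torch.Tensor] = {}

    def maybe_capture_hidden_states(self, layer_idx: int,
                                    hidden_states: torch.Tensor,
                                    residual: Optional[torch.Tensor]) -> None:
        # No-op for layers the drafter did not ask for.
        if layer_idx not in self.layers_to_capture:
            return
        # Store the full activation: hidden_states plus the residual stream,
        # when the residual is carried separately as in the diff.
        full = hidden_states if residual is None else hidden_states + residual
        self.captured[layer_idx] = full


# Usage mirroring the per-layer hook added in the diff:
spec_metadata = _SpecMetadataSketch(layers_to_capture={0})
hidden_states = torch.randn(4, 16)
residual = torch.randn(4, 16)
if spec_metadata is not None:
    spec_metadata.maybe_capture_hidden_states(0, hidden_states, residual)
print(sorted(spec_metadata.captured))  # [0]
```

In the real model the layer index comes from `self.layer_idx`, and the metadata object is threaded down from `Qwen3ForCausalLM.forward` through `Qwen3Model` into each decoder layer, exactly as the added `spec_metadata=spec_metadata` keyword arguments in the hunks above show.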