
 import torch
 from torch import nn
+from tqdm import tqdm
 from transformers import Qwen3Config

+from tensorrt_llm._utils import is_sm_100f
 from tensorrt_llm.functional import PositionEmbeddingType
+from tensorrt_llm.quantization.utils.fp8_utils import (
+    resmooth_to_fp8_e8m0, transform_sf_into_required_layout)

 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
@@ -49,10 +53,6 @@ def __init__(
             rope=RopeParams.from_config(config),
         )

-        # Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712
-        # and https://nvbugspro.nvidia.com/bug/5505402)
-        disable_deep_gemm = True
-
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -65,7 +65,6 @@ def __init__(
             dtype=config.torch_dtype,
             dense_bias=config.attention_bias,
             config=model_config,
-            disable_deep_gemm=disable_deep_gemm,
         )


@@ -86,18 +85,13 @@ def __init__(
         self.mapping = model_config.mapping
         self.enable_attention_dp = self.mapping.enable_attention_dp

-        # Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712
-        # and https://nvbugspro.nvidia.com/bug/5505402)
-        disable_deep_gemm = True
-
         self.mlp = GatedMLP(
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             bias=config.mlp_bias if hasattr(config, "mlp_bias") else False,
             dtype=config.torch_dtype,
             overridden_tp_size=1 if self.enable_attention_dp else None,
             config=model_config,
-            disable_deep_gemm=disable_deep_gemm,
         )

         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
@@ -223,3 +217,24 @@ def __init__(
             Qwen3Model(model_config),
             model_config,
         )
+
+    def post_load_weights(self):
+        all_named_modules = dict(self.model.named_modules())
+        for name, module in tqdm(all_named_modules.items(),
+                                 desc="Post loading weights"):
+            if len(module._parameters) <= 0 or name.startswith("draft_model"):
+                continue
+            else:
+                if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
+                ) and is_sm_100f() and hasattr(module, "weight_scale"):
+                    # Resmooth block-scaled FP8 weights to e8m0 scale factors on SM100.
+                    weight, weight_scale = resmooth_to_fp8_e8m0(
+                        module.weight, module.weight_scale)
+                    # Rearrange the scale factors into the layout the kernel requires.
+                    transformed_scale = transform_sf_into_required_layout(
+                        weight_scale,
+                        mn=weight.shape[0],
+                        k=weight.shape[1],
+                        recipe=(1, 128, 128),
+                        is_sfa=False)
+                    module.weight = nn.Parameter(weight, requires_grad=False)
+                    module.weight_scale = nn.Parameter(transformed_scale,
+                                                       requires_grad=False)
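
For context, a rough sketch (not part of this change) of the shape relationship the `(1, 128, 128)` recipe above works with: a block-scaled weight of shape `(mn, k)` carries one scale factor per 128x128 block. The tensor names and the max-based reduction below are illustrative assumptions, not the `fp8_utils` implementation.

```python
import torch

# Illustrative sketch only: one scale factor per (128, 128) weight block,
# which is the granularity assumed from the recipe=(1, 128, 128) call site.
mn, k, block = 512, 1024, 128
weight = torch.randn(mn, k)              # stand-in for the FP8 weight

# Per-block absolute maxima; a real quantizer would derive `weight_scale`
# from these (e.g. divided by the FP8 max representable value).
per_block_amax = weight.abs().reshape(mn // block, block,
                                      k // block, block).amax(dim=(1, 3))
print(per_block_amax.shape)              # torch.Size([4, 8])
```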