
Commit 77b68d9

[https://nvbugs/5461712] [fix] Use DG for Qwen3 Linear layers (#8030)
Signed-off-by: Aurelien Chartier <[email protected]>
1 parent: c8f98b3

File tree

1 file changed: +25 −10 lines changed

tensorrt_llm/_torch/models/modeling_qwen3.py

Lines changed: 25 additions & 10 deletions
@@ -2,9 +2,13 @@
 
 import torch
 from torch import nn
+from tqdm import tqdm
 from transformers import Qwen3Config
 
+from tensorrt_llm._utils import is_sm_100f
 from tensorrt_llm.functional import PositionEmbeddingType
+from tensorrt_llm.quantization.utils.fp8_utils import (
+    resmooth_to_fp8_e8m0, transform_sf_into_required_layout)
 
 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
@@ -49,10 +53,6 @@ def __init__(
             rope=RopeParams.from_config(config),
         )
 
-        # Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712
-        # and https://nvbugspro.nvidia.com/bug/5505402)
-        disable_deep_gemm = True
-
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -65,7 +65,6 @@ def __init__(
             dtype=config.torch_dtype,
             dense_bias=config.attention_bias,
             config=model_config,
-            disable_deep_gemm=disable_deep_gemm,
         )
 
 
@@ -86,18 +85,13 @@ def __init__(
         self.mapping = model_config.mapping
         self.enable_attention_dp = self.mapping.enable_attention_dp
 
-        # Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712
-        # and https://nvbugspro.nvidia.com/bug/5505402)
-        disable_deep_gemm = True
-
         self.mlp = GatedMLP(
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             bias=config.mlp_bias if hasattr(config, "mlp_bias") else False,
             dtype=config.torch_dtype,
             overridden_tp_size=1 if self.enable_attention_dp else None,
             config=model_config,
-            disable_deep_gemm=disable_deep_gemm,
         )
 
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
@@ -223,3 +217,24 @@ def __init__(
             Qwen3Model(model_config),
             model_config,
         )
+
+    def post_load_weights(self):
+        all_named_modules = dict(self.model.named_modules())
+        for name, module in tqdm(all_named_modules.items(),
+                                 desc="Post loading weights"):
+            if len(module._parameters) <= 0 or name.startswith("draft_model"):
+                continue
+            else:
+                if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
+                ) and is_sm_100f() and hasattr(module, "weight_scale"):
+                    weight, weight_scale = resmooth_to_fp8_e8m0(
+                        module.weight, module.weight_scale)
+                    transformed_scale = transform_sf_into_required_layout(
+                        weight_scale,
+                        mn=weight.shape[0],
+                        k=weight.shape[1],
+                        recipe=(1, 128, 128),
+                        is_sfa=False)
+                    module.weight = nn.Parameter(weight, requires_grad=False)
+                    module.weight_scale = nn.Parameter(transformed_scale,
+                                                       requires_grad=False)
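
For context: instead of the removed disable_deep_gemm workaround, the new post_load_weights hook re-packs each FP8 block-scale weight into the e8m0 scale layout that DeepGEMM consumes on SM100-class GPUs. The sketch below is a minimal, hedged illustration of the same per-module transform, reusing only the helpers imported by this commit; the wrapper function name and the simplified guard (the real hook also checks quant_config.layer_quant_mode.has_fp8_block_scales() and skips draft-model modules) are illustrative assumptions, not part of the commit.

from torch import nn

from tensorrt_llm._utils import is_sm_100f
from tensorrt_llm.quantization.utils.fp8_utils import (
    resmooth_to_fp8_e8m0, transform_sf_into_required_layout)


def resmooth_linear_for_deep_gemm(module: nn.Module) -> None:
    # Hypothetical helper mirroring the body of post_load_weights() for one
    # module. Assumes `module` carries an FP8 block-quantized `weight` of
    # shape [out_features, in_features] and a matching `weight_scale` tensor.
    if not (is_sm_100f() and hasattr(module, "weight_scale")):
        return  # only relevant for FP8 block-scale weights on SM100f hardware

    # Re-smooth the FP8 weight so its block scales can be expressed as e8m0.
    weight, weight_scale = resmooth_to_fp8_e8m0(module.weight,
                                                module.weight_scale)
    # Convert the scale factors into the layout required for the
    # (1, 128, 128) block recipe, matching the committed call (is_sfa=False).
    transformed_scale = transform_sf_into_required_layout(weight_scale,
                                                          mn=weight.shape[0],
                                                          k=weight.shape[1],
                                                          recipe=(1, 128, 128),
                                                          is_sfa=False)
    # Swap the re-packed tensors back onto the module as frozen parameters.
    module.weight = nn.Parameter(weight, requires_grad=False)
    module.weight_scale = nn.Parameter(transformed_scale, requires_grad=False)

Assuming the hook runs once after checkpoint loading, as its name suggests, the runtime GEMM path needs no per-call conversion: the DeepGEMM-backed Linear layers simply see weights whose scales are already in the required layout.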
