
Commit 4e55b3a

refine weight-only linear method
Signed-off-by: Yuening Li <[email protected]>
1 parent 941c288 commit 4e55b3a

File tree

1 file changed: +32 -30 lines changed


tensorrt_llm/_torch/modules/linear.py

Lines changed: 32 additions & 30 deletions
@@ -110,16 +110,12 @@ def load_weights_vanilla_helper(module: Linear, weights: List[Dict]):
     weight = load_weight_shard(weights[0]['weight'], module.tp_size,
                                module.tp_rank, module.tp_mode, device)

-    if module.has_w4a16_awq or module.has_weight_only_quant:
+    if module.has_weight_only_quant:
         # NOTE: without the preprocess during the runtime, the gemm output nan's. in order to use the preprocess_weights_for_mixed_gemm
         # we need to cast the weight to int8 first.
-        if module.has_w4a16_awq or module.quant_config.layer_quant_mode.is_int4_weight_only(
-        ):
-            quant_mode = torch.quint4x2
-        elif module.quant_config.layer_quant_mode.is_int8_weight_only():
-            quant_mode = torch.int8
+        weight_dtype, _ = get_weight_dtype_and_id(module)
         weight = preprocess_weights_for_mixed_gemm(
-            weight.T.to(torch.int8).contiguous().cpu(), quant_mode,
+            weight.T.to(torch.int8).contiguous().cpu(), weight_dtype,
             torch.float16).cuda().contiguous()

     copy_weight(module.weight, weight)
@@ -174,6 +170,27 @@ def load_weights_fused_gate_up_helper(
     return (gate_weight, up_weight)


+def get_weight_dtype_and_id(module: Linear) -> tuple[torch.dtype, int]:
+    """
+    Get weight dtype and weight_id for weight only quantization mode.
+
+    Returns:
+        tuple[torch.dtype, int]: (weight_dtype, weight_id) where:
+            - weight_dtype: torch.int8 for INT8 weights, torch.quint4x2 for INT4 weights
+            - weight_id: 1 for INT8, 2 for INT4 (used for weight packing)
+    """
+    assert module.quant_config is not None and module.quant_config.layer_quant_mode.is_weight_only(
+    ), "This function should only be called when the module has weight-only quantization enabled."
+
+    if module.quant_config.layer_quant_mode.is_int8_weight_only():
+        return torch.int8, 1
+    elif module.quant_config.layer_quant_mode.is_int4_weight_only():
+        return torch.quint4x2, 2
+    else:
+        raise ValueError(
+            f"Unsupported quant_mode: {module.quant_config.layer_quant_mode}")
+
+
 class LinearMethodBase(ABC):
     """
     Base class for all linear methods.
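
For context on the weight_id packing factor documented in the new helper: torch.quint4x2 stores two 4-bit values per byte, so INT4 weights occupy half as many int8 storage elements as they have weight values, which is what weight_id = 2 reflects. The snippet below is an illustrative sketch of such nibble packing, not code from this file; the function name is hypothetical and the actual layout produced by preprocess_weights_for_mixed_gemm may differ.

import torch

# Illustrative only: pack an even-length tensor of 4-bit values (0..15) into bytes.
def pack_int4_pairs(nibbles: torch.Tensor) -> torch.Tensor:
    assert nibbles.numel() % 2 == 0
    lo = nibbles[0::2] & 0x0F             # even elements become the low nibble
    hi = (nibbles[1::2] & 0x0F) << 4      # odd elements become the high nibble
    return (lo | hi).to(torch.uint8)

packed = pack_int4_pairs(torch.tensor([1, 14, 3, 4]))
print(packed.numel())  # 2 bytes hold 4 int4 values -> packing factor of 2
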
@@ -232,20 +249,6 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         """
         raise NotImplementedError

-    def _get_weight_dtype_and_id(self,
-                                 module: Linear) -> tuple[torch.dtype, int]:
-        """
-        get weight dtype and weight_id for weight only quantization mode
-        """
-        if module.quant_config.layer_quant_mode.is_int8_weight_only():
-            return torch.int8, 1
-        elif module.quant_config.layer_quant_mode.is_int4_weight_only():
-            return torch.quint4x2, 2
-        else:
-            raise ValueError(
-                f"Unsupported quant_mode: {module.quant_config.layer_quant_mode}"
-            )
-

 class UnquantizedLinearMethod(LinearMethodBase):

@@ -900,7 +903,7 @@ def create_weights(self, module: Linear, in_features: int,
                        out_features: int, bias: bool,
                        dtype: torch.dtype) -> None:

-        _, weight_id = self._get_weight_dtype_and_id(module)
+        _, weight_id = get_weight_dtype_and_id(module)

         # Quantized weights (int4 weights are packed into int8)
         module.weight = Parameter(torch.empty(
@@ -920,12 +923,12 @@ def create_weights(self, module: Linear, in_features: int,
     def apply(self, module: Linear, input: torch.Tensor,
               bias: Optional[torch.Tensor]) -> torch.Tensor:

-        weight_dtype, _ = self._get_weight_dtype_and_id(module)
+        weight_dtype, _ = get_weight_dtype_and_id(module)
         bias = bias.contiguous() if bias is not None else None

         output = torch.ops.trtllm.weight_only_quant_gemm(
-            input.to(module.dtype).contiguous(), module.weight, weight_dtype,
-            module.weight_scale, module.dtype)
+            input, module.weight, weight_dtype, module.weight_scale,
+            module.dtype)

         return output

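
For orientation, a plain-PyTorch sketch of the math a weight-only quantized GEMM computes: dequantize the stored integer weights with per-output-channel scales, then matmul in the activation dtype. This is illustrative only; the trtllm.weight_only_quant_gemm kernel fuses these steps and expects weights pre-laid-out by preprocess_weights_for_mixed_gemm, and the helper below is hypothetical.

import torch

# Illustrative reference, not the trtllm kernel.
def weight_only_gemm_reference(x: torch.Tensor,        # [M, K] fp16/bf16 activations
                               w_q: torch.Tensor,      # [K, N] int8-stored quantized weights
                               scale: torch.Tensor,    # [N]    per-output-channel scales
                               out_dtype: torch.dtype) -> torch.Tensor:
    w = w_q.to(x.dtype) * scale.to(x.dtype)  # dequantize weights to activation dtype
    return (x @ w).to(out_dtype)
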
@@ -972,7 +975,7 @@ def load_weights_fused_qkv_linear(self, module: Linear,

         fused_weight = torch.cat((q_weight, k_weight, v_weight))

-        weight_dtype, _ = self._get_weight_dtype_and_id(module)
+        weight_dtype, _ = get_weight_dtype_and_id(module)
         fused_weight = preprocess_weights_for_mixed_gemm(
             fused_weight.to(torch.int8).T.contiguous().cpu(), weight_dtype,
             torch.float16).cuda().contiguous()
@@ -988,10 +991,10 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
     def load_weights_fused_gate_up_linear(self, module: Linear,
                                           weights: List[Dict]) -> None:
         device = torch.device('cuda')
-        weight_dtype, _ = self._get_weight_dtype_and_id(module)
-
+        weight_dtype, _ = get_weight_dtype_and_id(module)
         gate_weight, up_weight = load_weights_fused_gate_up_helper(
             module, weights)
+
         fused_weight = torch.cat((gate_weight, up_weight))

         fused_weight = preprocess_weights_for_mixed_gemm(
@@ -1050,8 +1053,7 @@ def apply(self, module: Linear, input: torch.Tensor,

         bias = bias.contiguous() if bias is not None else None

-        output = torch.ops.trtllm.w4a16_gemm(input.to(
-            module.dtype).contiguous(),
+        output = torch.ops.trtllm.w4a16_gemm(input,
                                              module.weight,
                                              module.weight_scale.T.contiguous(),
                                              module.quant_config.group_size,
