@@ -110,16 +110,12 @@ def load_weights_vanilla_helper(module: Linear, weights: List[Dict]):
     weight = load_weight_shard(weights[0]['weight'], module.tp_size,
                                module.tp_rank, module.tp_mode, device)

-    if module.has_w4a16_awq or module.has_weight_only_quant:
+    if module.has_weight_only_quant:
         # NOTE: without the preprocess during the runtime, the gemm output nan's. in order to use the preprocess_weights_for_mixed_gemm
         # we need to cast the weight to int8 first.
-        if module.has_w4a16_awq or module.quant_config.layer_quant_mode.is_int4_weight_only(
-        ):
-            quant_mode = torch.quint4x2
-        elif module.quant_config.layer_quant_mode.is_int8_weight_only():
-            quant_mode = torch.int8
+        weight_dtype, _ = get_weight_dtype_and_id(module)
         weight = preprocess_weights_for_mixed_gemm(
-            weight.T.to(torch.int8).contiguous().cpu(), quant_mode,
+            weight.T.to(torch.int8).contiguous().cpu(), weight_dtype,
             torch.float16).cuda().contiguous()

     copy_weight(module.weight, weight)
@@ -174,6 +170,27 @@ def load_weights_fused_gate_up_helper(
     return (gate_weight, up_weight)


+def get_weight_dtype_and_id(module: Linear) -> tuple[torch.dtype, int]:
+    """
+    Get weight dtype and weight_id for weight only quantization mode.
+
+    Returns:
+        tuple[torch.dtype, int]: (weight_dtype, weight_id) where:
+            - weight_dtype: torch.int8 for INT8 weights, torch.quint4x2 for INT4 weights
+            - weight_id: 1 for INT8, 2 for INT4 (used for weight packing)
+    """
+    assert module.quant_config is not None and module.quant_config.layer_quant_mode.is_weight_only(
+    ), "This function should only be called when the module has weight-only quantization enabled."
+
+    if module.quant_config.layer_quant_mode.is_int8_weight_only():
+        return torch.int8, 1
+    elif module.quant_config.layer_quant_mode.is_int4_weight_only():
+        return torch.quint4x2, 2
+    else:
+        raise ValueError(
+            f"Unsupported quant_mode: {module.quant_config.layer_quant_mode}")
+
+
 class LinearMethodBase(ABC):
     """
     Base class for all linear methods.
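
For orientation, a minimal sketch (not part of the diff) of how the now module-level helper is consumed; the call pattern mirrors load_weights_vanilla_helper above, and the surrounding setup is assumed:

    # INT8 weight-only -> (torch.int8, 1); INT4 weight-only -> (torch.quint4x2, 2)
    weight_dtype, weight_id = get_weight_dtype_and_id(module)
    # The dtype tag is what preprocess_weights_for_mixed_gemm expects for the
    # int8-cast weight tensor.
    weight = preprocess_weights_for_mixed_gemm(
        weight.T.to(torch.int8).contiguous().cpu(), weight_dtype,
        torch.float16).cuda().contiguous()
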
@@ -232,20 +249,6 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         """
         raise NotImplementedError

-    def _get_weight_dtype_and_id(self,
-                                 module: Linear) -> tuple[torch.dtype, int]:
-        """
-        get weight dtype and weight_id for weight only quantization mode
-        """
-        if module.quant_config.layer_quant_mode.is_int8_weight_only():
-            return torch.int8, 1
-        elif module.quant_config.layer_quant_mode.is_int4_weight_only():
-            return torch.quint4x2, 2
-        else:
-            raise ValueError(
-                f"Unsupported quant_mode: {module.quant_config.layer_quant_mode}"
-            )
-

 class UnquantizedLinearMethod(LinearMethodBase):

@@ -900,7 +903,7 @@ def create_weights(self, module: Linear, in_features: int,
                        out_features: int, bias: bool,
                        dtype: torch.dtype) -> None:

-        _, weight_id = self._get_weight_dtype_and_id(module)
+        _, weight_id = get_weight_dtype_and_id(module)

         # Quantized weights (int4 weights are packed into int8)
         module.weight = Parameter(torch.empty(
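
A hedged illustration of how the returned weight_id is typically used for packing; the shape and layout below are assumptions for illustration, not the file's actual allocation, which the hunk cuts off mid-statement:

    # Illustration (assumed layout): with weight_id == 2, two int4 values share
    # one int8 storage element, so the packed dimension shrinks by that factor.
    _, weight_id = get_weight_dtype_and_id(module)
    module.weight = Parameter(torch.empty(
        (in_features, out_features // weight_id), dtype=torch.int8),
                              requires_grad=False)
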
@@ -920,12 +923,12 @@ def create_weights(self, module: Linear, in_features: int,
     def apply(self, module: Linear, input: torch.Tensor,
               bias: Optional[torch.Tensor]) -> torch.Tensor:

-        weight_dtype, _ = self._get_weight_dtype_and_id(module)
+        weight_dtype, _ = get_weight_dtype_and_id(module)
         bias = bias.contiguous() if bias is not None else None

         output = torch.ops.trtllm.weight_only_quant_gemm(
-            input.to(module.dtype).contiguous(), module.weight, weight_dtype,
-            module.weight_scale, module.dtype)
+            input, module.weight, weight_dtype, module.weight_scale,
+            module.dtype)

         return output

@@ -972,7 +975,7 @@ def load_weights_fused_qkv_linear(self, module: Linear,

         fused_weight = torch.cat((q_weight, k_weight, v_weight))

-        weight_dtype, _ = self._get_weight_dtype_and_id(module)
+        weight_dtype, _ = get_weight_dtype_and_id(module)
         fused_weight = preprocess_weights_for_mixed_gemm(
             fused_weight.to(torch.int8).T.contiguous().cpu(), weight_dtype,
             torch.float16).cuda().contiguous()
@@ -988,10 +991,10 @@ def load_weights_fused_qkv_linear(self, module: Linear,
     def load_weights_fused_gate_up_linear(self, module: Linear,
                                           weights: List[Dict]) -> None:
         device = torch.device('cuda')
-        weight_dtype, _ = self._get_weight_dtype_and_id(module)
-
+        weight_dtype, _ = get_weight_dtype_and_id(module)
         gate_weight, up_weight = load_weights_fused_gate_up_helper(
             module, weights)
+
         fused_weight = torch.cat((gate_weight, up_weight))

         fused_weight = preprocess_weights_for_mixed_gemm(
@@ -1050,8 +1053,7 @@ def apply(self, module: Linear, input: torch.Tensor,

         bias = bias.contiguous() if bias is not None else None

-        output = torch.ops.trtllm.w4a16_gemm(input.to(
-            module.dtype).contiguous(),
+        output = torch.ops.trtllm.w4a16_gemm(input,
                                              module.weight,
                                              module.weight_scale.T.contiguous(),
                                              module.quant_config.group_size,