Commit eebbd11

Add W4A8NVFP4FP8LinearMethod
Signed-off-by: Shiyang Chen <[email protected]>
1 parent 8540922 commit eebbd11

File tree

2 files changed: +200 -0 lines changed


tensorrt_llm/_torch/modules/linear.py

Lines changed: 187 additions & 0 deletions
@@ -820,6 +820,191 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         copy_weight(module.alpha, alpha)
 
 
+class W4A8NVFP4FP8LinearMethod(LinearMethodBase):
+
+    def create_weights(self, module: Linear, in_features: int,
+                       out_features: int, bias: bool, dtype: torch.dtype):
+        module.epilogue_tile_m = 128
+        module.scaling_vector_size = 32
+        assert in_features % module.scaling_vector_size == 0, (
+            f"in_features {in_features} must be divisible by scaling_vector_size {module.scaling_vector_size}"
+        )
+
+        # Quantized weights
+        module.weight = Parameter(
+            torch.empty([out_features, in_features // 2],
+                        dtype=fp4_utils.float4_e2m1x2),
+            requires_grad=False,
+        )
+
+        # FP8 per-block scaling factors. dtype must be aligned with SF_DTYPE
+        # Padding is required. See computeSFSize in quantization.h
+        nrows = fp4_utils.pad_up(out_features, 128)
+        ncols = fp4_utils.pad_up(in_features // module.scaling_vector_size, 4)
+        module.weight_scale = Parameter(torch.empty(
+            [nrows * ncols], dtype=fp4_utils.float4_sf_dtype),
+                                        requires_grad=False)
+
+        # amax_input / 448
+        module.input_scale = Parameter(torch.empty([1], dtype=torch.float32),
+                                       requires_grad=False)
+        # amax_weight / 448
+        module.weight_scale_2 = Parameter(torch.empty([1], dtype=torch.float32),
+                                          requires_grad=False)
+        # (amax_input * amax_weight) / (448 * 448)
+        module.alpha = Parameter(torch.empty([1], dtype=torch.float32),
+                                 requires_grad=False)
+
+        if bias:
+            module.bias = Parameter(torch.empty((out_features), dtype=dtype),
+                                    requires_grad=False)
+        else:
+            module.register_parameter("bias", None)
+
+    def apply(self, module: Linear, input: torch.Tensor,
+              bias: Optional[torch.Tensor]):
+        alpha = module.alpha
+        if input.dtype != torch.float8_e4m3fn:
+            if module.input_scale is not None and not module.force_dynamic_quantization:
+                # Static quantization
+                fp8_input, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
+                    input, module.input_scale)
+            else:
+                # Dynamic quantization
+                fp8_input, input_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(
+                    input)
+                alpha = module.weight_scale_2 * input_scale.to(torch.float32)
+
+        else:
+            fp8_input = input
+        output = torch.ops.trtllm.fp4_fp8_gemm_trtllmgen(
+            fp8_input,
+            module.weight,
+            module.weight_scale,
+            alpha,
+        )
+        if bias is not None:
+            output = output + bias
+        return output
+
+    def load_weight_scales(
+        self,
+        weights: List[Dict],
+        tp_size: int = 1,
+        tp_rank: int = 0,
+        tp_mode: Optional[TensorParallelMode] = None,
+    ):
+        # For concatenated weights (qkv_proj / up_gate_proj), the global scaling factors and input scaling factors should be shared.
+        input_scale = None
+        weight_scale_2 = None
+        weight_scale = []
+
+        device = torch.device("cuda")
+
+        for w in weights:
+            if "input_scale" in w:
+                if input_scale is None:
+                    input_scale = w["input_scale"][...]
+                else:
+                    assert input_scale == w["input_scale"][
+                        ...], "The input_scale should be same for all the weights"
+            if "weight_scale" in w:
+                ws = load_weight_shard(w["weight_scale"],
+                                       tp_size,
+                                       tp_rank,
+                                       tp_mode,
+                                       device=device).contiguous()
+                assert ws.dtype == torch.float8_e4m3fn  # TODO: or e8m0 for mxfp4 recipe?
+                weight_scale.append(ws.view(fp4_utils.float4_sf_dtype))
+            if "weight_scale_2" in w:
+                if weight_scale_2 is None:
+                    weight_scale_2 = w["weight_scale_2"][...]
+                else:
+                    assert weight_scale_2 == w["weight_scale_2"][...], (
+                        "The weight_scale_2 should be same for all the weights")
+
+        # TODO: ModelOpt's o_proj.weight_scale_2 is bfloat16, which should be float32
+        input_scale = input_scale.to(torch.float32)
+        weight_scale_2 = weight_scale_2.to(torch.float32)
+        alpha = input_scale * weight_scale_2
+        return input_scale, weight_scale, weight_scale_2, alpha
+
+    def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
+        # FIXME: this depends on the kernel internals
+        load_weights_vanilla_helper(
+            module, weights,
+            lambda w: fp4_utils.shuffle_matrix_a(w, module.epilogue_tile_m))
+
+        input_scale, weight_scale, weight_scale_2, alpha = self.load_weight_scales(
+            weights,
+            tp_size=module.tp_size,
+            tp_rank=module.tp_rank,
+            tp_mode=module.tp_mode)
+
+        assert len(weights) == 1
+        weight_scale = weight_scale[0]
+        # Shuffle and swizzle weight scale
+        weight_scale = fp4_utils.shuffle_matrix_sf_a(weight_scale,
+                                                     module.epilogue_tile_m,
+                                                     module.scaling_vector_size)
+
+        copy_weight(module.input_scale, input_scale)
+        copy_weight(module.weight_scale, weight_scale)
+        copy_weight(module.weight_scale_2, weight_scale_2)
+        copy_weight(module.alpha, alpha)
+
+    def load_weights_fused_qkv_linear(self, module: Linear,
+                                      weights: List[Dict]) -> None:
+        q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
+            module, weights)
+
+        input_scale, weight_scales, weight_scale_2, alpha = self.load_weight_scales(
+            weights,
+            tp_size=module.tp_size,
+            tp_rank=module.tp_rank,
+            tp_mode=module.tp_mode)
+        # Swizzle weight scales after concatenation
+        weight_scale = torch.cat(weight_scales, 0)
+        # Shuffle and swizzle weight scale
+        weight_scale = fp4_utils.shuffle_matrix_sf_a(weight_scale,
+                                                     module.epilogue_tile_m,
+                                                     module.scaling_vector_size)
+        copy_weight(module.input_scale, input_scale)
+        copy_weight(module.weight_scale, weight_scale)
+        copy_weight(module.weight_scale_2, weight_scale_2)
+        copy_weight(module.alpha, alpha)
+
+        fused_weight = torch.cat((q_weight, k_weight, v_weight))
+        fused_weight = fp4_utils.shuffle_matrix_a(fused_weight,
+                                                  module.epilogue_tile_m)
+        copy_weight(module.weight, fused_weight)
+
+    def load_weights_fused_gate_up_linear(self, module: Linear,
+                                          weights: List[Dict]) -> None:
+        gate_weight, up_weight = load_weights_fused_gate_up_helper(
+            module, weights)
+        fused_weight = torch.cat((gate_weight, up_weight))
+        fused_weight = fp4_utils.shuffle_matrix_a(fused_weight,
+                                                  module.epilogue_tile_m)
+        copy_weight(module.weight, fused_weight)
+
+        input_scale, weight_scales, weight_scale_2, alpha = self.load_weight_scales(
+            weights,
+            tp_size=module.tp_size,
+            tp_rank=module.tp_rank,
+            tp_mode=module.tp_mode)
+        # Swizzle weight scales after concatenation
+        weight_scale = torch.cat(weight_scales, 0)
+        # Shuffle and swizzle weight scale
+        weight_scale = fp4_utils.shuffle_matrix_sf_a(weight_scale,
+                                                     module.epilogue_tile_m,
+                                                     module.scaling_vector_size)
+        copy_weight(module.input_scale, input_scale)
+        copy_weight(module.weight_scale, weight_scale)
+        copy_weight(module.weight_scale_2, weight_scale_2)
+        copy_weight(module.alpha, alpha)
+
+
 class W4A8MXFP4FP8LinearMethod(LinearMethodBase):
 
     def create_weights(self, module: Linear, in_features: int,
@@ -1480,6 +1665,8 @@ def get_quant_method(quant_config: Optional[QuantConfig] = None):
         return FP8BlockScalesLinearMethod()
     if quant_config.layer_quant_mode.has_nvfp4():
         return NVFP4LinearMethod()
+    if quant_config.layer_quant_mode.has_w4a8_nvfp4_fp8():
+        return W4A8NVFP4FP8LinearMethod()
     if quant_config.layer_quant_mode.has_w4a8_mxfp4_fp8():
         return W4A8MXFP4FP8LinearMethod()
     if quant_config.layer_quant_mode.is_weight_only(
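
To make the bookkeeping in create_weights above easier to follow, here is a minimal sketch that restates the arithmetic visible in the diff: the 128-row / 4-column padding of the per-block scaling-factor buffer and the amax / 448 per-tensor scales. The local pad_up helper and the concrete shapes and amax values are illustrative assumptions, not part of the commit.

# Sketch only: restates the shape and scale arithmetic from create_weights above.
# pad_up() here stands in for fp4_utils.pad_up; shapes and amax values are made up.
def pad_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple

out_features, in_features = 5120, 4096
scaling_vector_size = 32                      # one FP8 scale per 32 FP4 weights

# weight_scale buffer: rows padded to 128, columns padded to 4 (see computeSFSize)
nrows = pad_up(out_features, 128)                        # 5120
ncols = pad_up(in_features // scaling_vector_size, 4)    # 128
num_scale_entries = nrows * ncols                        # 655360 FP8 entries

# Per-tensor scales follow the FP8 E4M3 convention (max finite value 448):
amax_input, amax_weight = 6.0, 0.5            # illustrative absolute maxima
input_scale = amax_input / 448.0              # module.input_scale
weight_scale_2 = amax_weight / 448.0          # module.weight_scale_2
alpha = input_scale * weight_scale_2          # module.alpha, passed to the GEMM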

tensorrt_llm/quantization/mode.py

Lines changed: 13 additions & 0 deletions
@@ -40,6 +40,7 @@ class QuantAlgo(StrEnum, metaclass=BaseEnumMeta):
     INT8 = auto()
     MIXED_PRECISION = auto()
     NVFP4 = auto()
+    W4A8_NVFP4_FP8 = auto()
     W4A8_MXFP4_FP8 = auto()
     W4A8_MXFP4_MXFP8 = auto()
     W4A16_MXFP4 = auto()
@@ -90,6 +91,8 @@ class QuantMode(IntFlag):
     # FP4
     NVFP4 = auto()
     NVFP4_KV_CACHE = auto()
+    # W4A8 NVFP4
+    W4A8_NVFP4_FP8 = auto()
     # W4A8 MXFP4
     W4A8_MXFP4_FP8 = auto()
     W4A8_MXFP4_MXFP8 = auto()
@@ -179,6 +182,9 @@ def has_fp8_rowwise(self):
     def has_nvfp4(self):
         return self._any(self.NVFP4)
 
+    def has_w4a8_nvfp4_fp8(self):
+        return self._any(self.W4A8_NVFP4_FP8)
+
     def has_w4a8_mxfp4_fp8(self):
         return self._any(self.W4A8_MXFP4_FP8)
 
@@ -203,6 +209,7 @@ def has_any_quant(self, exclude_kv_cache: bool = False):
                              | self.W4A8_QSERVE
                              | self.FP8_1x128_128x128
                              | self.NVFP4
+                             | self.W4A8_NVFP4_FP8
                              | self.W4A8_MXFP4_FP8
                              | self.W4A16_MXFP4
                              | self.W4A8_MXFP4_MXFP8)
@@ -240,6 +247,7 @@ def from_description(quantize_weights=False,
                         use_fp8_block_scales=False,
                         use_fp8_rowwise=False,
                         use_nvfp4=False,
+                        use_w4a8_nvfp4_fp8=False,
                         use_w4a8_qserve=False,
                         use_w4a8_mxfp4_fp8=False,
                         use_w4a8_mxfp4_mxfp8=False,
@@ -313,6 +321,9 @@ def raise_error():
         if use_nvfp4:
             mode = mode | QuantMode.NVFP4
 
+        if use_w4a8_nvfp4_fp8:
+            mode = mode | QuantMode.W4A8_NVFP4_FP8
+
         # W4A8 QServe
         if use_w4a8_qserve:
             mode = mode | QuantMode.W4A8_QSERVE
@@ -399,6 +410,8 @@ def from_quant_algo(
         quant_mode = QuantMode.from_description(use_fp8_block_scales=True)
     elif quant_algo == QuantAlgo.NVFP4:
         quant_mode = QuantMode.from_description(use_nvfp4=True)
+    elif quant_algo == QuantAlgo.W4A8_NVFP4_FP8:
+        quant_mode = QuantMode.from_description(use_w4a8_nvfp4_fp8=True)
     elif quant_algo == QuantAlgo.W4A8_MXFP4_FP8:
         quant_mode = QuantMode.from_description(use_w4a8_mxfp4_fp8=True)
     elif quant_algo == QuantAlgo.W4A8_MXFP4_MXFP8:
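
Read together with the linear.py hunk above, the intended plumbing for the new algorithm can be sketched as follows. This is only an illustration of the code added in this commit, and the import path is assumed from the file location rather than taken from the diff.

# Sketch of how the new enum value is expected to resolve, based on the additions above.
from tensorrt_llm.quantization.mode import QuantAlgo, QuantMode  # assumed import path

mode = QuantMode.from_quant_algo(QuantAlgo.W4A8_NVFP4_FP8)
assert mode.has_w4a8_nvfp4_fp8()
assert mode.has_any_quant()
# With this flag set on a layer's quant mode, get_quant_method() in
# tensorrt_llm/_torch/modules/linear.py now returns W4A8NVFP4FP8LinearMethod.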
