
Commit 47fefb3

Add W4A8NVFP4FP8LinearMethod
Signed-off-by: Shiyang Chen <[email protected]>
1 parent 8540922 commit 47fefb3

File tree: 2 files changed, +208 -0 lines changed

tensorrt_llm/_torch/modules/linear.py

Lines changed: 193 additions & 0 deletions
@@ -820,6 +820,191 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         copy_weight(module.alpha, alpha)
 
 
+class W4A8NVFP4FP8LinearMethod(LinearMethodBase):
+
+    def create_weights(self, module: Linear, in_features: int,
+                       out_features: int, bias: bool, dtype: torch.dtype):
+        module.epilogue_tile_m = 128
+        module.scaling_vector_size = 32
+        assert in_features % module.scaling_vector_size == 0, (
+            f"in_features {in_features} must be divisible by scaling_vector_size {module.scaling_vector_size}"
+        )
+
+        # Quantized weights
+        module.weight = Parameter(
+            torch.empty([out_features, in_features // 2],
+                        dtype=fp4_utils.float4_e2m1x2),
+            requires_grad=False,
+        )
+
+        # FP8 per-block scaling factors. dtype must be aligned with SF_DTYPE
+        # Padding is required. See computeSFSize in quantization.h
+        nrows = fp4_utils.pad_up(out_features, 128)
+        ncols = fp4_utils.pad_up(in_features // module.scaling_vector_size, 4)
+        module.weight_scale = Parameter(torch.empty(
+            [nrows * ncols], dtype=fp4_utils.float4_sf_dtype),
+                                        requires_grad=False)
+
+        # amax_input / 448
+        module.input_scale = Parameter(torch.empty([1], dtype=torch.float32),
+                                       requires_grad=False)
+        # amax_weight / 448
+        module.weight_scale_2 = Parameter(torch.empty([1], dtype=torch.float32),
+                                          requires_grad=False)
+        # (amax_input * amax_weight) / (448 * 448)
+        module.alpha = Parameter(torch.empty([1], dtype=torch.float32),
+                                 requires_grad=False)
+
+        if bias:
+            module.bias = Parameter(torch.empty((out_features), dtype=dtype),
+                                    requires_grad=False)
+        else:
+            module.register_parameter("bias", None)
+
+    def apply(self, module: Linear, input: torch.Tensor,
+              bias: Optional[torch.Tensor]):
+        alpha = module.alpha
+        if input.dtype != torch.float8_e4m3fn:
+            if module.input_scale is not None and not module.force_dynamic_quantization:
+                # Static quantization
+                fp8_input, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
+                    input, module.input_scale)
+            else:
+                # Dynamic quantization
+                fp8_input, input_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(
+                    input)
+                alpha = module.weight_scale_2 * input_scale.to(torch.float32)
+
+        else:
+            fp8_input = input
+        output = torch.ops.trtllm.fp4_fp8_gemm_trtllmgen(
+            fp8_input, module.weight,
+            module.weight_scale.view(dtype=torch.float8_e4m3fn), alpha,
+            module.dtype)
+        if bias is not None:
+            output = output + bias
+        return output
+
+    def load_weight_scales(
+        self,
+        weights: List[Dict],
+        tp_size: int = 1,
+        tp_rank: int = 0,
+        tp_mode: Optional[TensorParallelMode] = None,
+    ):
+        # For concatenated weights (qkv_proj / up_gate_proj), the global scaling factors and input scaling factors should be shared.
+        input_scale = None
+        weight_scale_2 = None
+        weight_scale = []
+
+        device = torch.device("cuda")
+
+        for w in weights:
+            if "input_scale" in w:
+                if input_scale is None:
+                    input_scale = w["input_scale"][...]
+                else:
+                    assert input_scale == w["input_scale"][
+                        ...], "The input_scale should be same for all the weights"
+            if "weight_scale" in w:
+                ws = load_weight_shard(w["weight_scale"],
+                                       tp_size,
+                                       tp_rank,
+                                       tp_mode,
+                                       device=device).contiguous()
+                assert ws.dtype == torch.float8_e4m3fn  # TODO: or e8m0 for mxfp4 recipe?
+                weight_scale.append(ws.view(fp4_utils.float4_sf_dtype))
+            if "weight_scale_2" in w:
+                if weight_scale_2 is None:
+                    weight_scale_2 = w["weight_scale_2"][...]
+                else:
+                    assert weight_scale_2 == w["weight_scale_2"][...], (
+                        "The weight_scale_2 should be same for all the weights")
+
+        # TODO: ModelOpt's o_proj.weight_scale_2 is bfloat16, which should be float32
+        input_scale = input_scale.to(torch.float32)
+        weight_scale_2 = weight_scale_2.to(torch.float32)
+        alpha = input_scale * weight_scale_2
+        return input_scale, weight_scale, weight_scale_2, alpha
+
+    def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
+        # FIXME: this depends on the kernel internals
+        load_weights_vanilla_helper(
+            module, weights,
+            lambda w: fp4_utils.shuffle_matrix_a(w, module.epilogue_tile_m))
+
+        input_scale, weight_scale, weight_scale_2, alpha = self.load_weight_scales(
+            weights,
+            tp_size=module.tp_size,
+            tp_rank=module.tp_rank,
+            tp_mode=module.tp_mode)
+
+        assert len(weights) == 1
+        weight_scale = weight_scale[0]
+        # Shuffle and Swizzle weight scale
+        weight_scale = fp4_utils.shuffle_matrix_sf_a(weight_scale,
+                                                     module.epilogue_tile_m,
+                                                     module.scaling_vector_size)
+        weight_scale = weight_scale.view(dtype=torch.float8_e4m3fn)
+        copy_weight(module.input_scale, input_scale)
+        copy_weight(module.weight_scale, weight_scale)
+        copy_weight(module.weight_scale_2, weight_scale_2)
+        copy_weight(module.alpha, alpha)
+
+    def load_weights_fused_qkv_linear(self, module: Linear,
+                                      weights: List[Dict]) -> None:
+        q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
+            module, weights)
+
+        input_scale, weight_scales, weight_scale_2, alpha = self.load_weight_scales(
+            weights,
+            tp_size=module.tp_size,
+            tp_rank=module.tp_rank,
+            tp_mode=module.tp_mode)
+        # Swizzle weight scales after concatenation
+        weight_scale = torch.cat(weight_scales, 0)
+        # Shuffle and Swizzle weight scale
+        weight_scale = fp4_utils.shuffle_matrix_sf_a(weight_scale,
+                                                     module.epilogue_tile_m,
+                                                     module.scaling_vector_size)
+        weight_scale = weight_scale.view(dtype=torch.float8_e4m3fn)
+        copy_weight(module.input_scale, input_scale)
+        copy_weight(module.weight_scale, weight_scale)
+        copy_weight(module.weight_scale_2, weight_scale_2)
+        copy_weight(module.alpha, alpha)
+
+        fused_weight = torch.cat((q_weight, k_weight, v_weight))
+        fused_weight = fp4_utils.shuffle_matrix_a(fused_weight,
+                                                  module.epilogue_tile_m)
+        copy_weight(module.weight, fused_weight)
+
+    def load_weights_fused_gate_up_linear(self, module: Linear,
+                                          weights: List[Dict]) -> None:
+        gate_weight, up_weight = load_weights_fused_gate_up_helper(
+            module, weights)
+        fused_weight = torch.cat((gate_weight, up_weight))
+        fused_weight = fp4_utils.shuffle_matrix_a(fused_weight,
+                                                  module.epilogue_tile_m)
+        copy_weight(module.weight, fused_weight)
+
+        input_scale, weight_scales, weight_scale_2, alpha = self.load_weight_scales(
+            weights,
+            tp_size=module.tp_size,
+            tp_rank=module.tp_rank,
+            tp_mode=module.tp_mode)
+        # Swizzle weight scales after concatenation
+        weight_scale = torch.cat(weight_scales, 0)
+        # Shuffle and Swizzle weight scale
+        weight_scale = fp4_utils.shuffle_matrix_sf_a(weight_scale,
+                                                     module.epilogue_tile_m,
+                                                     module.scaling_vector_size)
+        weight_scale = weight_scale.view(dtype=torch.float8_e4m3fn)
+        copy_weight(module.input_scale, input_scale)
+        copy_weight(module.weight_scale, weight_scale)
+        copy_weight(module.weight_scale_2, weight_scale_2)
+        copy_weight(module.alpha, alpha)
+
+
 class W4A8MXFP4FP8LinearMethod(LinearMethodBase):
 
     def create_weights(self, module: Linear, in_features: int,
@@ -1480,6 +1665,8 @@ def get_quant_method(quant_config: Optional[QuantConfig] = None):
         return FP8BlockScalesLinearMethod()
     if quant_config.layer_quant_mode.has_nvfp4():
        return NVFP4LinearMethod()
+    if quant_config.layer_quant_mode.has_w4a8_nvfp4_fp8():
+        return W4A8NVFP4FP8LinearMethod()
    if quant_config.layer_quant_mode.has_w4a8_mxfp4_fp8():
        return W4A8MXFP4FP8LinearMethod()
    if quant_config.layer_quant_mode.is_weight_only(
@@ -1634,6 +1821,12 @@ def has_w4a8_awq(self):
        return self.quant_config is not None and self.quant_config.layer_quant_mode.is_int4_weight_only_per_group(
        ) and self.quant_config.quant_algo == QuantAlgo.W4A8_AWQ
 
+    @property
+    def has_w4a8_nvfp4_fp8(self):
+        assert self._weights_created
+        return self.quant_config is not None and self.quant_config.layer_quant_mode.has_w4a8_nvfp4_fp8(
+        )
+
    @property
    def has_w4a8_mxfp4_fp8(self):
        assert self._weights_created
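Editor's note: the per-tensor scales in the new method follow the comments in create_weights — input_scale is amax_input / 448, weight_scale_2 is amax_weight / 448 (448 being the largest normal value of float8_e4m3fn), and alpha is their product — while the FP8 block-scale buffer is padded to multiples of 128 rows and 4 columns of in_features // 32 scaling groups. A minimal sketch of that arithmetic in plain PyTorch, with illustrative shapes and amax values that are not taken from the commit; pad_up below is a local stand-in for fp4_utils.pad_up:

import torch


def pad_up(x: int, multiple: int) -> int:
    # Round x up to the next multiple (stand-in for fp4_utils.pad_up).
    return (x + multiple - 1) // multiple * multiple


# Illustrative layer shape and calibration amax values (hypothetical numbers).
out_features, in_features = 4096, 4096
scaling_vector_size = 32            # NVFP4 group size along in_features
amax_input, amax_weight = 6.0, 0.5

# Per-tensor scales, mirroring the comments in create_weights.
# 448 is the largest normal value representable in float8_e4m3fn.
input_scale = torch.tensor([amax_input / 448.0], dtype=torch.float32)      # amax_input / 448
weight_scale_2 = torch.tensor([amax_weight / 448.0], dtype=torch.float32)  # amax_weight / 448
alpha = input_scale * weight_scale_2  # (amax_input * amax_weight) / (448 * 448)

# Packed FP4 weights occupy in_features // 2 bytes per output row, and the
# per-block scale buffer is padded to 128-row by 4-column tiles before flattening.
packed_weight_shape = (out_features, in_features // 2)
nrows = pad_up(out_features, 128)
ncols = pad_up(in_features // scaling_vector_size, 4)
weight_scale_numel = nrows * ncols

print(packed_weight_shape, weight_scale_numel, input_scale.item(), alpha.item())

At apply time the same alpha is either taken from these precomputed scales (static quantization) or recomputed from the dynamically measured input scale, as the apply method in the diff above shows.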

tensorrt_llm/quantization/mode.py

Lines changed: 15 additions & 0 deletions
@@ -40,6 +40,7 @@ class QuantAlgo(StrEnum, metaclass=BaseEnumMeta):
     INT8 = auto()
     MIXED_PRECISION = auto()
     NVFP4 = auto()
+    W4A8_NVFP4_FP8 = auto()
     W4A8_MXFP4_FP8 = auto()
     W4A8_MXFP4_MXFP8 = auto()
     W4A16_MXFP4 = auto()
@@ -90,6 +91,8 @@ class QuantMode(IntFlag):
     # FP4
     NVFP4 = auto()
     NVFP4_KV_CACHE = auto()
+    # W4A8 NVFP4
+    W4A8_NVFP4_FP8 = auto()
     # W4A8 MXFP4
     W4A8_MXFP4_FP8 = auto()
     W4A8_MXFP4_MXFP8 = auto()
@@ -179,6 +182,9 @@ def has_fp8_rowwise(self):
     def has_nvfp4(self):
         return self._any(self.NVFP4)
 
+    def has_w4a8_nvfp4_fp8(self):
+        return self._any(self.W4A8_NVFP4_FP8)
+
     def has_w4a8_mxfp4_fp8(self):
         return self._any(self.W4A8_MXFP4_FP8)
 
@@ -203,6 +209,7 @@ def has_any_quant(self, exclude_kv_cache: bool = False):
                               | self.W4A8_QSERVE
                               | self.FP8_1x128_128x128
                               | self.NVFP4
+                              | self.W4A8_NVFP4_FP8
                               | self.W4A8_MXFP4_FP8
                               | self.W4A16_MXFP4
                               | self.W4A8_MXFP4_MXFP8)
@@ -240,6 +247,7 @@ def from_description(quantize_weights=False,
                          use_fp8_block_scales=False,
                          use_fp8_rowwise=False,
                          use_nvfp4=False,
+                         use_w4a8_nvfp4_fp8=False,
                          use_w4a8_qserve=False,
                          use_w4a8_mxfp4_fp8=False,
                          use_w4a8_mxfp4_mxfp8=False,
@@ -313,6 +321,9 @@ def raise_error():
         if use_nvfp4:
             mode = mode | QuantMode.NVFP4
 
+        if use_w4a8_nvfp4_fp8:
+            mode = mode | QuantMode.W4A8_NVFP4_FP8
+
         # W4A8 QServe
         if use_w4a8_qserve:
             mode = mode | QuantMode.W4A8_QSERVE
@@ -399,6 +410,8 @@ def from_quant_algo(
             quant_mode = QuantMode.from_description(use_fp8_block_scales=True)
         elif quant_algo == QuantAlgo.NVFP4:
             quant_mode = QuantMode.from_description(use_nvfp4=True)
+        elif quant_algo == QuantAlgo.W4A8_NVFP4_FP8:
+            quant_mode = QuantMode.from_description(use_w4a8_nvfp4_fp8=True)
         elif quant_algo == QuantAlgo.W4A8_MXFP4_FP8:
             quant_mode = QuantMode.from_description(use_w4a8_mxfp4_fp8=True)
         elif quant_algo == QuantAlgo.W4A8_MXFP4_MXFP8:
@@ -437,6 +450,8 @@ def to_dict(self):
                 self.has_fp8_block_scales(),
             'enable_nvfp4':
                 self.has_nvfp4(),
+            'enable_w4a8_nvfp4_fp8':
+                self.has_w4a8_nvfp4_fp8(),
             'enable_w4a8_mxfp4_fp8':
                 self.has_w4a8_mxfp4_fp8(),
             'enable_w4a8_mxfp4_mxfp8':
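Editor's note: the mode.py changes wire the new algorithm through from_quant_algo, has_any_quant, and to_dict. A minimal usage sketch, assuming tensorrt_llm is importable and that from_quant_algo can be called on the class the same way the existing NVFP4 path is exercised; this is an illustration, not part of the commit:

from tensorrt_llm.quantization.mode import QuantAlgo, QuantMode

# Resolve the new algorithm name to a QuantMode flag set, as from_quant_algo now does.
mode = QuantMode.from_quant_algo(QuantAlgo.W4A8_NVFP4_FP8)

assert mode.has_w4a8_nvfp4_fp8()   # predicate added in this commit
assert mode.has_any_quant()        # W4A8_NVFP4_FP8 now counts as a quantized mode
print(mode.to_dict()["enable_w4a8_nvfp4_fp8"])  # new key emitted by to_dict

A QuantConfig whose layer_quant_mode carries this flag would then make get_quant_method in linear.py return W4A8NVFP4FP8LinearMethod, per the dispatch added above.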
