@@ -820,6 +820,191 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         copy_weight(module.alpha, alpha)


+class W4A8NVFP4FP8LinearMethod(LinearMethodBase):
+
+    def create_weights(self, module: Linear, in_features: int,
+                       out_features: int, bias: bool, dtype: torch.dtype):
+        module.epilogue_tile_m = 128
+        module.scaling_vector_size = 32
+        assert in_features % module.scaling_vector_size == 0, (
+            f"in_features {in_features} must be divisible by scaling_vector_size {module.scaling_vector_size}"
+        )
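+        # Each FP8 block scale covers one group of scaling_vector_size (32)
+        # FP4 weights along in_features, hence the divisibility requirement.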
+
+        # Quantized weights, packed two FP4 (e2m1) values per byte
+        module.weight = Parameter(
+            torch.empty([out_features, in_features // 2],
+                        dtype=fp4_utils.float4_e2m1x2),
+            requires_grad=False,
+        )
+
+        # FP8 per-block scaling factors. dtype must be aligned with SF_DTYPE.
+        # Padding is required. See computeSFSize in quantization.h.
+        nrows = fp4_utils.pad_up(out_features, 128)
+        ncols = fp4_utils.pad_up(in_features // module.scaling_vector_size, 4)
+        module.weight_scale = Parameter(torch.empty(
+            [nrows * ncols], dtype=fp4_utils.float4_sf_dtype),
+                                        requires_grad=False)
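+        # The scale tensor is stored flat; the row/column padding above is
+        # assumed to match the swizzled layout the GEMM kernel reads
+        # (see computeSFSize).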
+
+        # amax_input / 448
+        module.input_scale = Parameter(torch.empty([1], dtype=torch.float32),
+                                       requires_grad=False)
+        # amax_weight / 448
+        module.weight_scale_2 = Parameter(torch.empty([1],
+                                                      dtype=torch.float32),
+                                          requires_grad=False)
+        # (amax_input * amax_weight) / (448 * 448)
+        module.alpha = Parameter(torch.empty([1], dtype=torch.float32),
+                                 requires_grad=False)
+
+        if bias:
+            module.bias = Parameter(torch.empty((out_features), dtype=dtype),
+                                    requires_grad=False)
+        else:
+            module.register_parameter("bias", None)
+
+    def apply(self, module: Linear, input: torch.Tensor,
+              bias: Optional[torch.Tensor]):
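+        # Quantize the activation to FP8 if it is not already, then run the
+        # FP8-activation x FP4-weight GEMM.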
+        alpha = module.alpha
+        if input.dtype != torch.float8_e4m3fn:
+            if module.input_scale is not None and not module.force_dynamic_quantization:
+                # Static quantization
+                fp8_input, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
+                    input, module.input_scale)
+            else:
+                # Dynamic quantization
+                fp8_input, input_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(
+                    input)
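+                # The input scale is computed per call here, so alpha must be
+                # refolded from it rather than taken from the preloaded value.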
+                alpha = module.weight_scale_2 * input_scale.to(torch.float32)
+        else:
+            fp8_input = input
+        output = torch.ops.trtllm.fp4_fp8_gemm_trtllmgen(
+            fp8_input,
+            module.weight,
+            module.weight_scale,
+            alpha,
+        )
+        if bias is not None:
+            output = output + bias
+        return output
+
+    def load_weight_scales(
+        self,
+        weights: List[Dict],
+        tp_size: int = 1,
+        tp_rank: int = 0,
+        tp_mode: Optional[TensorParallelMode] = None,
+    ):
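+        # Gathers the shared per-tensor scales (input_scale, weight_scale_2)
+        # and the per-block weight_scale shard from each weight dict.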
+        # For concatenated weights (qkv_proj / up_gate_proj), the global
+        # scaling factors and input scaling factors should be shared.
+        input_scale = None
+        weight_scale_2 = None
+        weight_scale = []
+
+        device = torch.device("cuda")
+
+        for w in weights:
+            if "input_scale" in w:
+                if input_scale is None:
+                    input_scale = w["input_scale"][...]
+                else:
+                    assert input_scale == w["input_scale"][
+                        ...], "The input_scale should be the same for all the weights"
911+ if "weight_scale" in w :
912+ ws = load_weight_shard (w ["weight_scale" ],
913+ tp_size ,
914+ tp_rank ,
915+ tp_mode ,
916+ device = device ).contiguous ()
917+ assert ws .dtype == torch .float8_e4m3fn # TODO: or e8m0 for mxfp4 recipe?
918+ weight_scale .append (ws .view (fp4_utils .float4_sf_dtype ))
919+ if "weight_scale_2" in w :
920+ if weight_scale_2 is None :
921+ weight_scale_2 = w ["weight_scale_2" ][...]
922+ else :
923+ assert weight_scale_2 == w ["weight_scale_2" ][...], (
924+ "The weight_scale_2 should be same for all the weights" )
925+
+        # TODO: ModelOpt's o_proj.weight_scale_2 is bfloat16, which should be float32
+        input_scale = input_scale.to(torch.float32)
+        weight_scale_2 = weight_scale_2.to(torch.float32)
+        alpha = input_scale * weight_scale_2
+        return input_scale, weight_scale, weight_scale_2, alpha
+
+    def load_weights_vanilla(self, module: Linear,
+                             weights: List[Dict]) -> None:
+        # FIXME: this depends on the kernel internals
+        load_weights_vanilla_helper(
+            module, weights,
+            lambda w: fp4_utils.shuffle_matrix_a(w, module.epilogue_tile_m))
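+        # shuffle_matrix_a is assumed to reorder the weight rows into the
+        # layout the trtllm-gen epilogue expects for its 128-row tile.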
+
+        input_scale, weight_scale, weight_scale_2, alpha = self.load_weight_scales(
+            weights,
+            tp_size=module.tp_size,
+            tp_rank=module.tp_rank,
+            tp_mode=module.tp_mode)
+
+        assert len(weights) == 1
+        weight_scale = weight_scale[0]
+        # Shuffle and swizzle the weight scale
+        weight_scale = fp4_utils.shuffle_matrix_sf_a(weight_scale,
+                                                     module.epilogue_tile_m,
+                                                     module.scaling_vector_size)
+
+        copy_weight(module.input_scale, input_scale)
+        copy_weight(module.weight_scale, weight_scale)
+        copy_weight(module.weight_scale_2, weight_scale_2)
+        copy_weight(module.alpha, alpha)
+
+    def load_weights_fused_qkv_linear(self, module: Linear,
+                                      weights: List[Dict]) -> None:
+        q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
+            module, weights)
+
+        input_scale, weight_scales, weight_scale_2, alpha = self.load_weight_scales(
+            weights,
+            tp_size=module.tp_size,
+            tp_rank=module.tp_rank,
+            tp_mode=module.tp_mode)
+        # Concatenate the weight scales, then shuffle and swizzle them
+        weight_scale = torch.cat(weight_scales, 0)
+        weight_scale = fp4_utils.shuffle_matrix_sf_a(weight_scale,
+                                                     module.epilogue_tile_m,
+                                                     module.scaling_vector_size)
+        copy_weight(module.input_scale, input_scale)
+        copy_weight(module.weight_scale, weight_scale)
+        copy_weight(module.weight_scale_2, weight_scale_2)
+        copy_weight(module.alpha, alpha)
+
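+        # Q, K and V are fused along out_features and given the same row
+        # shuffle as the vanilla path, so a single GEMM serves all three.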
+        fused_weight = torch.cat((q_weight, k_weight, v_weight))
+        fused_weight = fp4_utils.shuffle_matrix_a(fused_weight,
+                                                  module.epilogue_tile_m)
+        copy_weight(module.weight, fused_weight)
+
+    def load_weights_fused_gate_up_linear(self, module: Linear,
+                                          weights: List[Dict]) -> None:
+        gate_weight, up_weight = load_weights_fused_gate_up_helper(
+            module, weights)
+        fused_weight = torch.cat((gate_weight, up_weight))
+        fused_weight = fp4_utils.shuffle_matrix_a(fused_weight,
+                                                  module.epilogue_tile_m)
+        copy_weight(module.weight, fused_weight)
+
+        input_scale, weight_scales, weight_scale_2, alpha = self.load_weight_scales(
+            weights,
+            tp_size=module.tp_size,
+            tp_rank=module.tp_rank,
+            tp_mode=module.tp_mode)
+        # Concatenate the weight scales, then shuffle and swizzle them
+        weight_scale = torch.cat(weight_scales, 0)
+        weight_scale = fp4_utils.shuffle_matrix_sf_a(weight_scale,
+                                                     module.epilogue_tile_m,
+                                                     module.scaling_vector_size)
+        copy_weight(module.input_scale, input_scale)
+        copy_weight(module.weight_scale, weight_scale)
+        copy_weight(module.weight_scale_2, weight_scale_2)
+        copy_weight(module.alpha, alpha)
+
+
 class W4A8MXFP4FP8LinearMethod(LinearMethodBase):

     def create_weights(self, module: Linear, in_features: int,
@@ -1480,6 +1665,8 @@ def get_quant_method(quant_config: Optional[QuantConfig] = None):
         return FP8BlockScalesLinearMethod()
     if quant_config.layer_quant_mode.has_nvfp4():
         return NVFP4LinearMethod()
+    if quant_config.layer_quant_mode.has_w4a8_nvfp4_fp8():
+        return W4A8NVFP4FP8LinearMethod()
     if quant_config.layer_quant_mode.has_w4a8_mxfp4_fp8():
         return W4A8MXFP4FP8LinearMethod()
     if quant_config.layer_quant_mode.is_weight_only(