@@ -820,6 +820,191 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         copy_weight(module.alpha, alpha)
 
 
+class W4A8NVFP4FP8LinearMethod(LinearMethodBase):
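+    """Linear method for W4A8 quantization: NVFP4 weights with FP8 activations.
+
+    Weights are stored as packed FP4 (e2m1) values with FP8 per-block scaling
+    factors (one scale per 32-element block of in_features); activations are
+    quantized to FP8 e4m3 per tensor, either statically or dynamically.
+    """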
+
+    def create_weights(self, module: Linear, in_features: int,
+                       out_features: int, bias: bool, dtype: torch.dtype):
+        module.epilogue_tile_m = 128
+        module.scaling_vector_size = 32
+        assert in_features % module.scaling_vector_size == 0, (
+            f"in_features {in_features} must be divisible by scaling_vector_size {module.scaling_vector_size}"
+        )
+
+        # Quantized weights
+        module.weight = Parameter(
+            torch.empty([out_features, in_features // 2],
+                        dtype=fp4_utils.float4_e2m1x2),
+            requires_grad=False,
+        )
+
+        # FP8 per-block scaling factors. dtype must be aligned with SF_DTYPE
+        # Padding is required. See computeSFSize in quantization.h
+        nrows = fp4_utils.pad_up(out_features, 128)
+        ncols = fp4_utils.pad_up(in_features // module.scaling_vector_size, 4)
+        module.weight_scale = Parameter(torch.empty(
+            [nrows * ncols], dtype=fp4_utils.float4_sf_dtype),
+                                        requires_grad=False)
+
+        # amax_input / 448
+        module.input_scale = Parameter(torch.empty([1], dtype=torch.float32),
+                                       requires_grad=False)
+        # amax_weight / 448
+        module.weight_scale_2 = Parameter(torch.empty([1], dtype=torch.float32),
+                                          requires_grad=False)
+        # (amax_input * amax_weight) / (448 * 448)
+        module.alpha = Parameter(torch.empty([1], dtype=torch.float32),
+                                 requires_grad=False)
+
+        if bias:
+            module.bias = Parameter(torch.empty((out_features), dtype=dtype),
+                                    requires_grad=False)
+        else:
+            module.register_parameter("bias", None)
+
+    def apply(self, module: Linear, input: torch.Tensor,
+              bias: Optional[torch.Tensor]):
+        alpha = module.alpha
+        if input.dtype != torch.float8_e4m3fn:
+            if module.input_scale is not None and not module.force_dynamic_quantization:
+                # Static quantization
+                fp8_input, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
+                    input, module.input_scale)
+            else:
+                # Dynamic quantization
+                fp8_input, input_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(
+                    input)
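+                # The input scale is only known at runtime here, so recompute
+                # alpha instead of using the precomputed module.alpha.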
+                alpha = module.weight_scale_2 * input_scale.to(torch.float32)
+        else:
+            fp8_input = input
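+        # FP8-activation x FP4-weight GEMM: weight_scale supplies the per-block
+        # FP8 scaling factors and alpha the per-tensor dequantization scale;
+        # the output is produced in module.dtype.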
+        output = torch.ops.trtllm.fp4_fp8_gemm_trtllmgen(
+            fp8_input, module.weight,
+            module.weight_scale.view(dtype=torch.float8_e4m3fn), alpha,
+            module.dtype)
+        if bias is not None:
+            output = output + bias
+        return output
+
+    def load_weight_scales(
+            self,
+            weights: List[Dict],
+            tp_size: int = 1,
+            tp_rank: int = 0,
+            tp_mode: Optional[TensorParallelMode] = None,
+    ):
+        # For concatenated weights (qkv_proj / up_gate_proj), the global scaling factors and input scaling factors should be shared.
+        input_scale = None
+        weight_scale_2 = None
+        weight_scale = []
+
+        device = torch.device("cuda")
+
+        for w in weights:
+            if "input_scale" in w:
+                if input_scale is None:
+                    input_scale = w["input_scale"][...]
+                else:
+                    assert input_scale == w["input_scale"][
+                        ...], "The input_scale should be same for all the weights"
+            if "weight_scale" in w:
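+                # The per-block scaling factors are sharded along the same TP
+                # dimension as the weight they belong to.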
+                ws = load_weight_shard(w["weight_scale"],
+                                       tp_size,
+                                       tp_rank,
+                                       tp_mode,
+                                       device=device).contiguous()
+                assert ws.dtype == torch.float8_e4m3fn  # TODO: or e8m0 for mxfp4 recipe?
+                weight_scale.append(ws.view(fp4_utils.float4_sf_dtype))
+            if "weight_scale_2" in w:
+                if weight_scale_2 is None:
+                    weight_scale_2 = w["weight_scale_2"][...]
+                else:
+                    assert weight_scale_2 == w["weight_scale_2"][...], (
+                        "The weight_scale_2 should be same for all the weights")
+
+        # TODO: ModelOpt's o_proj.weight_scale_2 is bfloat16, which should be float32
+        input_scale = input_scale.to(torch.float32)
+        weight_scale_2 = weight_scale_2.to(torch.float32)
+        alpha = input_scale * weight_scale_2
+        return input_scale, weight_scale, weight_scale_2, alpha
+
+    def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
+        # FIXME: this depends on the kernel internals
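+        # shuffle_matrix_a reorders the packed FP4 weight rows into the layout
+        # expected by the GEMM kernel (epilogue_tile_m = 128 row tiles).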
+        load_weights_vanilla_helper(
+            module, weights,
+            lambda w: fp4_utils.shuffle_matrix_a(w, module.epilogue_tile_m))
+
+        input_scale, weight_scale, weight_scale_2, alpha = self.load_weight_scales(
+            weights,
+            tp_size=module.tp_size,
+            tp_rank=module.tp_rank,
+            tp_mode=module.tp_mode)
+
+        assert len(weights) == 1
+        weight_scale = weight_scale[0]
+        # Shuffle and swizzle the weight scale
+        weight_scale = fp4_utils.shuffle_matrix_sf_a(weight_scale,
+                                                     module.epilogue_tile_m,
+                                                     module.scaling_vector_size)
+        weight_scale = weight_scale.view(dtype=torch.float8_e4m3fn)
+        copy_weight(module.input_scale, input_scale)
+        copy_weight(module.weight_scale, weight_scale)
+        copy_weight(module.weight_scale_2, weight_scale_2)
+        copy_weight(module.alpha, alpha)
+
+    def load_weights_fused_qkv_linear(self, module: Linear,
+                                      weights: List[Dict]) -> None:
+        q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
+            module, weights)
+
+        input_scale, weight_scales, weight_scale_2, alpha = self.load_weight_scales(
+            weights,
+            tp_size=module.tp_size,
+            tp_rank=module.tp_rank,
+            tp_mode=module.tp_mode)
+        # Swizzle weight scales after concatenation
+        weight_scale = torch.cat(weight_scales, 0)
+        # Shuffle and swizzle the weight scale
+        weight_scale = fp4_utils.shuffle_matrix_sf_a(weight_scale,
+                                                     module.epilogue_tile_m,
+                                                     module.scaling_vector_size)
+        weight_scale = weight_scale.view(dtype=torch.float8_e4m3fn)
+        copy_weight(module.input_scale, input_scale)
+        copy_weight(module.weight_scale, weight_scale)
+        copy_weight(module.weight_scale_2, weight_scale_2)
+        copy_weight(module.alpha, alpha)
+
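+        # Q/K/V weights are concatenated first so the shuffle reordering is
+        # applied across the fused output dimension.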
+        fused_weight = torch.cat((q_weight, k_weight, v_weight))
+        fused_weight = fp4_utils.shuffle_matrix_a(fused_weight,
+                                                  module.epilogue_tile_m)
+        copy_weight(module.weight, fused_weight)
+
+    def load_weights_fused_gate_up_linear(self, module: Linear,
+                                          weights: List[Dict]) -> None:
+        gate_weight, up_weight = load_weights_fused_gate_up_helper(
+            module, weights)
+        fused_weight = torch.cat((gate_weight, up_weight))
+        fused_weight = fp4_utils.shuffle_matrix_a(fused_weight,
+                                                  module.epilogue_tile_m)
+        copy_weight(module.weight, fused_weight)
+
+        input_scale, weight_scales, weight_scale_2, alpha = self.load_weight_scales(
+            weights,
+            tp_size=module.tp_size,
+            tp_rank=module.tp_rank,
+            tp_mode=module.tp_mode)
+        # Swizzle weight scales after concatenation
+        weight_scale = torch.cat(weight_scales, 0)
+        # Shuffle and swizzle the weight scale
+        weight_scale = fp4_utils.shuffle_matrix_sf_a(weight_scale,
+                                                     module.epilogue_tile_m,
+                                                     module.scaling_vector_size)
+        weight_scale = weight_scale.view(dtype=torch.float8_e4m3fn)
+        copy_weight(module.input_scale, input_scale)
+        copy_weight(module.weight_scale, weight_scale)
+        copy_weight(module.weight_scale_2, weight_scale_2)
+        copy_weight(module.alpha, alpha)
+
+
 class W4A8MXFP4FP8LinearMethod(LinearMethodBase):
 
     def create_weights(self, module: Linear, in_features: int,
@@ -1480,6 +1665,8 @@ def get_quant_method(quant_config: Optional[QuantConfig] = None):
         return FP8BlockScalesLinearMethod()
     if quant_config.layer_quant_mode.has_nvfp4():
         return NVFP4LinearMethod()
+    if quant_config.layer_quant_mode.has_w4a8_nvfp4_fp8():
+        return W4A8NVFP4FP8LinearMethod()
     if quant_config.layer_quant_mode.has_w4a8_mxfp4_fp8():
         return W4A8MXFP4FP8LinearMethod()
     if quant_config.layer_quant_mode.is_weight_only(
@@ -1634,6 +1821,12 @@ def has_w4a8_awq(self):
         return self.quant_config is not None and self.quant_config.layer_quant_mode.is_int4_weight_only_per_group(
         ) and self.quant_config.quant_algo == QuantAlgo.W4A8_AWQ
 
+    @property
+    def has_w4a8_nvfp4_fp8(self):
+        assert self._weights_created
+        return self.quant_config is not None and self.quant_config.layer_quant_mode.has_w4a8_nvfp4_fp8(
+        )
+
     @property
     def has_w4a8_mxfp4_fp8(self):
         assert self._weights_created