diff --git a/src/brevitas/quant/experimental/mx_quant_ocp.py b/src/brevitas/quant/experimental/mx_quant_ocp.py index d682da2dc..b2d719bc6 100644 --- a/src/brevitas/quant/experimental/mx_quant_ocp.py +++ b/src/brevitas/quant/experimental/mx_quant_ocp.py @@ -206,6 +206,7 @@ def build_options( raise RuntimeError("Not supported") if is_po2_scale: + assert scale_rounding_func_type is not None scale_rounding_func = scale_rounding_func_dict[scale_rounding_func_type] options['restrict_scaling_type'] = RestrictValueType.POWER_OF_TWO options['restrict_value_float_to_int_impl'] = scale_rounding_func diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 98e467708..457877459 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -21,6 +21,8 @@ from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat from brevitas.quant.experimental.mx_quant_ocp import Fp8e4m3WeightSymmetricGroupQuant +from brevitas.quant.experimental.mx_quant_ocp import GroupwiseFloatWeightQuantizerBuilder +from brevitas.quant.experimental.mx_quant_ocp import GroupwiseIntWeightQuantizerBuilder from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3WeightMSE @@ -222,7 +224,8 @@ def generate_quantizers( quantize_input_zero_point=False, device=None, weight_kwargs=None, - input_kwargs=None): + input_kwargs=None, + weight_scale_rounding_func_type=None): """ Replace float layers with quant layers in the target model """ @@ -243,8 +246,32 @@ def generate_quantizers( else: input_float_format = {} - weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_precision][ - weight_param_method][weight_quant_granularity][weight_quant_type] + if weight_quant_granularity == 'per_group': + if weight_quant_format == 'int': + weight_quant = GroupwiseIntWeightQuantizerBuilder( + bit_width=weight_bit_width, + scale_stats_op='max' if weight_param_method != 'mse' else weight_param_method, + is_po2_scale=weight_scale_precision == 'po2_scale', + scale_computation_type='parameter_from_stats', + scale_rounding_func_type=weight_scale_rounding_func_type, + group_dim=weight_group_dim, + group_size=weight_group_size, + scaling_min_val=1e-4 if dtype == torch.float16 else 1e-8) + else: + weight_quant = GroupwiseFloatWeightQuantizerBuilder( + exponent_bit_width=weight_float_format['exponent_bit_width'], + mantissa_bit_width=weight_float_format['mantissa_bit_width'], + bit_width=weight_bit_width, + scale_stats_op='max' if weight_param_method != 'mse' else weight_param_method, + is_po2_scale=weight_scale_precision == 'po2_scale', + scale_computation_type='parameter_from_stats', + scale_rounding_func_type=weight_scale_rounding_func_type, + group_dim=weight_group_dim, + group_size=weight_group_size, + scaling_min_val=1e-4 if dtype == torch.float16 else 1e-8) + else: + weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_precision][ + weight_param_method][weight_quant_granularity][weight_quant_type] if input_bit_width is not None and input_scale_type == 'no_scale': input_quant = sym_input_quant = linear_input_quant = INPUT_QUANT_MAP[input_quant_format][ diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 4a87f5a1a..5ef39fffa 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -253,7 +253,9 @@ def main(args): input_quant_granularity=args.input_quant_granularity, input_group_size=args.input_group_size, quantize_input_zero_point=args.quantize_input_zero_point, - device=device) + device=device, + weight_scale_rounding_func_type=args.weight_scale_rounding_func_type + ) layer_map = generate_quant_maps( linear_input_quant=linear_input_quant, weight_quant=weight_quant, @@ -400,6 +402,12 @@ def parse_args(args): default='per_group', choices=['per_channel', 'per_tensor', 'per_group'], help='Granularity for scales/zero-point of weights. Default: per_group.') + parser.add_argument( + '--weight-scale-rounding-func-type', + type=str, + default=None, + choices=['round', 'ceil', 'floor'], + help='Rounding function to use with Po2 scale. Default: None.') parser.add_argument( '--weight-group-dim', type=int,