
Integration with llm entrypoint
Giuseppe5 committed Oct 28, 2024
1 parent 87358df commit dca94fa
Showing 3 changed files with 40 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/brevitas/quant/experimental/mx_quant_ocp.py
@@ -206,6 +206,7 @@ def build_options(
raise RuntimeError("Not supported")

if is_po2_scale:
assert scale_rounding_func_type is not None
scale_rounding_func = scale_rounding_func_dict[scale_rounding_func_type]
options['restrict_scaling_type'] = RestrictValueType.POWER_OF_TWO
options['restrict_value_float_to_int_impl'] = scale_rounding_func
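The added assert turns a missing rounding function into an explicit failure: with a power-of-two scale requested, the dict lookup right below it would otherwise fail with a less obvious KeyError on None. A minimal sketch of the guarded pattern, using a hypothetical helper name and a stand-in mapping in place of the module's scale_rounding_func_dict (the round/ceil/floor keys mirror the CLI choices added below):

import math

# Stand-in for scale_rounding_func_dict; the real mapping lives in mx_quant_ocp.py.
scale_rounding_func_dict = {'round': round, 'ceil': math.ceil, 'floor': math.floor}

def build_po2_scale_options(is_po2_scale, scale_rounding_func_type):
    options = {}
    if is_po2_scale:
        # Fail fast with a clear assertion instead of a KeyError further down.
        assert scale_rounding_func_type is not None
        options['restrict_value_float_to_int_impl'] = scale_rounding_func_dict[
            scale_rounding_func_type]
    return options

build_po2_scale_options(True, 'ceil')    # OK
# build_po2_scale_options(True, None)    # now raises AssertionError
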
33 changes: 30 additions & 3 deletions src/brevitas_examples/common/generative/quantize.py
@@ -21,6 +21,8 @@
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat
from brevitas.quant.experimental.mx_quant_ocp import Fp8e4m3WeightSymmetricGroupQuant
from brevitas.quant.experimental.mx_quant_ocp import GroupwiseFloatWeightQuantizerBuilder
from brevitas.quant.experimental.mx_quant_ocp import GroupwiseIntWeightQuantizerBuilder
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3WeightMSE
@@ -222,7 +224,8 @@ def generate_quantizers(
quantize_input_zero_point=False,
device=None,
weight_kwargs=None,
input_kwargs=None):
input_kwargs=None,
weight_scale_rounding_func_type=None):
"""
Replace float layers with quant layers in the target model
"""
@@ -243,8 +246,32 @@
else:
input_float_format = {}

weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_precision][
weight_param_method][weight_quant_granularity][weight_quant_type]
if weight_quant_granularity == 'per_group':
if weight_quant_format == 'int':
weight_quant = GroupwiseIntWeightQuantizerBuilder(
bit_width=weight_bit_width,
scale_stats_op='max' if weight_param_method != 'mse' else weight_param_method,
is_po2_scale=weight_scale_precision == 'po2_scale',
scale_computation_type='parameter_from_stats',
scale_rounding_func_type=weight_scale_rounding_func_type,
group_dim=weight_group_dim,
group_size=weight_group_size,
scaling_min_val=1e-4 if dtype == torch.float16 else 1e-8)
else:
weight_quant = GroupwiseFloatWeightQuantizerBuilder(
exponent_bit_width=weight_float_format['exponent_bit_width'],
mantissa_bit_width=weight_float_format['mantissa_bit_width'],
bit_width=weight_bit_width,
scale_stats_op='max' if weight_param_method != 'mse' else weight_param_method,
is_po2_scale=weight_scale_precision == 'po2_scale',
scale_computation_type='parameter_from_stats',
scale_rounding_func_type=weight_scale_rounding_func_type,
group_dim=weight_group_dim,
group_size=weight_group_size,
scaling_min_val=1e-4 if dtype == torch.float16 else 1e-8)
else:
weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_precision][
weight_param_method][weight_quant_granularity][weight_quant_type]

if input_bit_width is not None and input_scale_type == 'no_scale':
input_quant = sym_input_quant = linear_input_quant = INPUT_QUANT_MAP[input_quant_format][
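For per-group weights, generate_quantizers now builds the quantizer through the new builder classes instead of the static WEIGHT_QUANT_MAP lookup. A standalone sketch of the integer branch, with illustrative values substituted for the function's arguments (the numbers are examples, not defaults):

import torch

from brevitas.quant.experimental.mx_quant_ocp import GroupwiseIntWeightQuantizerBuilder

dtype = torch.float16  # stands in for the dtype seen by generate_quantizers

weight_quant = GroupwiseIntWeightQuantizerBuilder(
    bit_width=4,
    scale_stats_op='max',  # 'mse' when weight_param_method == 'mse'
    is_po2_scale=True,  # weight_scale_precision == 'po2_scale'
    scale_computation_type='parameter_from_stats',
    scale_rounding_func_type='ceil',  # the new knob threaded through from the entrypoint
    group_dim=1,
    group_size=32,
    scaling_min_val=1e-4 if dtype == torch.float16 else 1e-8)

The float branch is the same call against GroupwiseFloatWeightQuantizerBuilder, with the exponent and mantissa bit widths from weight_float_format passed in addition.
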
10 changes: 9 additions & 1 deletion src/brevitas_examples/llm/main.py
@@ -253,7 +253,9 @@ def main(args):
input_quant_granularity=args.input_quant_granularity,
input_group_size=args.input_group_size,
quantize_input_zero_point=args.quantize_input_zero_point,
device=device)
device=device,
weight_scale_rounding_func_type=args.weight_scale_rounding_func_type
)
layer_map = generate_quant_maps(
linear_input_quant=linear_input_quant,
weight_quant=weight_quant,
@@ -400,6 +402,12 @@ def parse_args(args):
default='per_group',
choices=['per_channel', 'per_tensor', 'per_group'],
help='Granularity for scales/zero-point of weights. Default: per_group.')
parser.add_argument(
'--weight-scale-rounding-func-type',
type=str,
default=None,
choices=['round', 'ceil', 'floor'],
help='Rounding function to use with Po2 scale. Default: None.')
parser.add_argument(
'--weight-group-dim',
type=int,
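End to end, the new option is just another flag of the LLM entrypoint: parse_args exposes it as args.weight_scale_rounding_func_type and main forwards it to generate_quantizers. A hypothetical invocation sketch showing only the flags relevant to this commit; the real entrypoint takes more arguments than appear in this diff:

from brevitas_examples.llm.main import main, parse_args

args = parse_args([
    '--weight-quant-granularity', 'per_group',
    '--weight-scale-rounding-func-type', 'ceil',
])
# main(args) then forwards args.weight_scale_rounding_func_type to
# generate_quantizers; in practice the command line carries more flags than shown.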
