From 799349e50ac56a4191fa7c9b6d89a00c657bc0a1 Mon Sep 17 00:00:00 2001
From: Alessandro Pappalardo
Date: Thu, 21 Sep 2023 17:28:40 +0100
Subject: [PATCH] Feat (examples/llm): add custom float support

---
 .../llm/llm_quant/quantize.py     | 170 ++++++++++++------
 .../llm/llm_quant/quantizers.py   |  27 ++-
 src/brevitas_examples/llm/main.py |  40 ++++-
 3 files changed, 167 insertions(+), 70 deletions(-)

diff --git a/src/brevitas_examples/llm/llm_quant/quantize.py b/src/brevitas_examples/llm/llm_quant/quantize.py
index f2d51f79a..647dd35a2 100644
--- a/src/brevitas_examples/llm/llm_quant/quantize.py
+++ b/src/brevitas_examples/llm/llm_quant/quantize.py
@@ -2,12 +2,17 @@
 Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 """
+import re
 
 from torch import nn
 
 from brevitas import nn as qnn
 from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
 from brevitas.graph.quantize import layerwise_quantize
+from brevitas.quant.experimental.float import Fp8e4m3Act
+from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
+from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
+from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat
 from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint
 from brevitas.quant.fixed_point import Int8ActPerTensorFixedPointMSE
 from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint
@@ -26,6 +31,7 @@
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatMSE
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE
+from brevitas_examples.llm.llm_quant.quantizers import Fp8e4m3WeightSymmetricGroupQuant
 from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerGroupFloat
 from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerRowFloat
 from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerTensorFloat
@@ -37,62 +43,82 @@
 from brevitas_examples.llm.llm_quant.quantizers import ShiftedUintWeightAsymmetricGroupQuant
 
 WEIGHT_QUANT_MAP = {
-    'float': {
-        'stats': {
-            'per_tensor': {
-                'sym': Int8WeightPerTensorFloat, 'asym': ShiftedUint8WeightPerTensorFloat},
-            'per_channel': {
-                'sym': Int8WeightPerChannelFloat, 'asym': ShiftedUint8WeightPerChannelFloat},
-            'per_group': {
-                'sym': IntWeightSymmetricGroupQuant, 'asym': ShiftedUintWeightAsymmetricGroupQuant},
-        },
-        'mse': {
-            'per_tensor': {
-                'sym': Int8WeightPerTensorFloatMSE, 'asym': ShiftedUint8WeightPerTensorFloatMSE},
-            'per_channel': {
-                'sym': Int8WeightPerChannelFloatMSE, 'asym': ShiftedUint8WeightPerChannelFloatMSE},
-        },},
-    'po2': {
-        'stats': {
-            'per_tensor': {
-                'sym': Int8WeightPerTensorFixedPoint},
-            'per_channel': {
-                'sym': Int8WeightPerChannelFixedPoint},},
-        'mse': {
-            'per_tensor': {
-                'sym': Int8WeightPerTensorFixedPointMSE},
-            'per_channel': {
-                'sym': Int8WeightPerChannelFixedPointMSE},},}}
-
-INPUT_QUANT_MAP = {
-    'static': {
-        'float': {
+    'int': {
+        'float_scale': {
             'stats': {
                 'per_tensor': {
-                    'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat},
-                'per_row': {
-                    'sym': Int8ActPerRowFloat, 'asym': ShiftedUint8ActPerRowFloat},},
+                    'sym': Int8WeightPerTensorFloat, 'asym': ShiftedUint8WeightPerTensorFloat},
+                'per_channel': {
+                    'sym': Int8WeightPerChannelFloat, 'asym': ShiftedUint8WeightPerChannelFloat},
+                'per_group': {
+                    'sym': IntWeightSymmetricGroupQuant,
+                    'asym': ShiftedUintWeightAsymmetricGroupQuant},},
             'mse': {
                 'per_tensor': {
-                    'sym': Int8ActPerTensorFloatMSE, 'asym': ShiftedUint8ActPerTensorFloatMSE},
-                'per_row': {
-                    'sym': Int8ActPerRowFloatMSE, 'asym': ShiftedUint8ActPerRowFloatMSE},},},
-        'po2': {
+                    'sym': Int8WeightPerTensorFloatMSE,
+                    'asym': ShiftedUint8WeightPerTensorFloatMSE},
+                'per_channel': {
+                    'sym': Int8WeightPerChannelFloatMSE,
+                    'asym': ShiftedUint8WeightPerChannelFloatMSE},},},
+        'po2_scale': {
             'stats': {
                 'per_tensor': {
-                    'sym': Int8ActPerTensorFixedPoint},},
+                    'sym': Int8WeightPerTensorFixedPoint},
+                'per_channel': {
+                    'sym': Int8WeightPerChannelFixedPoint},},
             'mse': {
                 'per_tensor': {
-                    'sym': Int8ActPerTensorFixedPointMSE},},}},
-    'dynamic': {
-        'float': {
+                    'sym': Int8WeightPerTensorFixedPointMSE},
+                'per_channel': {
+                    'sym': Int8WeightPerChannelFixedPointMSE},},}},
+    'float': {
+        'float_scale': {
             'stats': {
                 'per_tensor': {
-                    'sym': Int8ActDynamicPerTensorFloat},
-                'per_row': {
-                    'sym': Int8ActDynamicPerRowFloat},
+                    'sym': Fp8e4m3WeightPerTensorFloat},
+                'per_channel': {
+                    'sym': Fp8e4m3WeightPerChannelFloat},
                 'per_group': {
-                    'sym': Int8ActDynamicPerGroupFloat},}}}}
+                    'sym': Fp8e4m3WeightSymmetricGroupQuant}},}}}
+
+INPUT_QUANT_MAP = {
+    'int': {
+        'static': {
+            'float_scale': {
+                'stats': {
+                    'per_tensor': {
+                        'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat},
+                    'per_row': {
+                        'sym': Int8ActPerRowFloat, 'asym': ShiftedUint8ActPerRowFloat},},
+                'mse': {
+                    'per_tensor': {
+                        'sym': Int8ActPerTensorFloatMSE, 'asym': ShiftedUint8ActPerTensorFloatMSE},
+                    'per_row': {
+                        'sym': Int8ActPerRowFloatMSE, 'asym': ShiftedUint8ActPerRowFloatMSE},},},
+            'po2_scale': {
+                'stats': {
+                    'per_tensor': {
+                        'sym': Int8ActPerTensorFixedPoint},},
+                'mse': {
+                    'per_tensor': {
+                        'sym': Int8ActPerTensorFixedPointMSE},},}},
+        'dynamic': {
+            'float_scale': {
+                'stats': {
+                    'per_tensor': {
+                        'sym': Int8ActDynamicPerTensorFloat},
+                    'per_row': {
+                        'sym': Int8ActDynamicPerRowFloat},
+                    'per_group': {
+                        'sym': Int8ActDynamicPerGroupFloat},}}}},
+    'float': {
+        'static': {
+            'float_scale': {
+                'stats': {
+                    'per_tensor': {
+                        'sym': Fp8e4m3ActPerTensorFloat},}}},
+        'no_scale': {
+            'sym': Fp8e4m3Act,}}}
 
 
 def quantize_model(
@@ -105,7 +131,9 @@ def quantize_model(
         weight_quant_granularity,
         weight_group_size,
         quantize_weight_zero_point,
+        weight_quant_format='int',
         input_bit_width=None,
+        input_quant_format='int',
         input_scale_precision=None,
         input_scale_type=None,
         input_param_method=None,
@@ -119,18 +147,38 @@ def quantize_model(
     Replace float layers with quant layers in the target model
     """
     # Retrieve base input and weight quantizers
-    weight_quant = WEIGHT_QUANT_MAP[weight_scale_precision][weight_param_method][
-        weight_quant_granularity][weight_quant_type]
-    if input_bit_width is not None:
-        input_quant = INPUT_QUANT_MAP[input_scale_type][input_scale_precision][input_param_method][
-            input_quant_granularity][input_quant_type]
+
+    # match against custom float format
+    if re.compile(r'e[1-8]m[1-8]').match(weight_quant_format):
+        weight_float_format = {
+            'exponent_bit_width': int(weight_quant_format[1]),
+            'mantissa_bit_width': int(weight_quant_format[3])}
+        weight_quant_format = 'float'
+    else:
+        weight_float_format = {}
+    if re.compile(r'e[1-8]m[1-8]').match(input_quant_format):
+        input_float_format = {
+            'exponent_bit_width': int(input_quant_format[1]),
+            'mantissa_bit_width': int(input_quant_format[3])}
+        input_quant_format = 'float'
+    else:
+        input_float_format = {}
+
+    weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_precision][
+        weight_param_method][weight_quant_granularity][weight_quant_type]
+    if input_bit_width is not None and input_scale_type == 'no_scale':
+        input_quant = sym_input_quant = linear_2d_input_quant = INPUT_QUANT_MAP[input_quant_format][
+            input_scale_type][input_quant_type]
+    elif input_bit_width is not None:
+        input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][input_scale_precision][
+            input_param_method][input_quant_granularity][input_quant_type]
         # Some activations in MHA should always be symmetric
-        sym_input_quant = INPUT_QUANT_MAP[input_scale_type][input_scale_precision][
-            input_param_method][input_quant_granularity]['sym']
+        sym_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][
+            input_scale_precision][input_param_method][input_quant_granularity]['sym']
         # Linear layers with 2d input should always be per tensor or per group, as there is no row dimension
         if input_quant_granularity == 'per_tensor' or input_quant_granularity == 'per_row':
-            linear_2d_input_quant = INPUT_QUANT_MAP[input_scale_type][input_scale_precision][
-                input_param_method]['per_tensor'][input_quant_type]
+            linear_2d_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][
+                input_scale_precision][input_param_method]['per_tensor'][input_quant_type]
         else:
             assert input_quant_granularity == 'per_group'
             linear_2d_input_quant = input_quant
@@ -145,7 +193,8 @@ def quantize_model(
             'bit_width': weight_bit_width,
             'narrow_range': False,
             'block_size': weight_group_size,
-            'quantize_zero_point': quantize_weight_zero_point})
+            'quantize_zero_point': quantize_weight_zero_point},
+        **weight_float_format)
     # weight scale is converted to a standalone parameter
     # This is done already by default in the per_group quantizer
     if weight_quant_granularity != 'per_group':
@@ -161,7 +210,8 @@ def quantize_model(
             **{
                 'bit_width': input_bit_width,
                 'quantize_zero_point': quantize_input_zero_point,
-                'dtype': dtype})
+                'dtype': dtype,},
+            **input_float_format)
         if input_scale_type == 'static' and input_quant_granularity == 'per_row':
             # QuantMHA internally always uses Seq, B, E
             input_quant = input_quant.let(
@@ -188,7 +238,8 @@ def quantize_model(
             **{
                 'bit_width': input_bit_width,
                 'quantize_zero_point': quantize_input_zero_point,
-                'dtype': dtype})
+                'dtype': dtype},
+            **input_float_format)
         if input_scale_type == 'static' and input_quant_granularity == 'per_row':
             q_scaled_quant = sym_input_quant.let(
                 **{
@@ -241,7 +292,8 @@ def quantize_model(
             **{
                 'bit_width': input_bit_width,
                 'quantize_zero_point': quantize_input_zero_point,
-                'dtype': dtype})
+                'dtype': dtype},
+            **input_float_format)
         if input_scale_type == 'dynamic':
             # Note: this breaks if applied to 3d Linear inputs,
             # in case standard MHA layers haven't been inserted
@@ -265,7 +317,7 @@ def quantize_model(
                 'in_proj_bias_quant': None,
                 'softmax_input_quant': None,
                 'attn_output_weights_quant': attn_output_weights_quant,
-                'attn_output_weights_signed': False,
+                'attn_output_weights_signed': input_quant_format == 'float',
                 'q_scaled_quant': q_scaled_quant,
                 'k_transposed_quant': k_transposed_quant,
                 'v_quant': v_quant,
diff --git a/src/brevitas_examples/llm/llm_quant/quantizers.py b/src/brevitas_examples/llm/llm_quant/quantizers.py
index 30aa658c5..28590a0e8 100644
--- a/src/brevitas_examples/llm/llm_quant/quantizers.py
+++ b/src/brevitas_examples/llm/llm_quant/quantizers.py
@@ -14,8 +14,10 @@
 from brevitas.core.stats import NegativePercentileOrZero
 from brevitas.core.zero_point import ParameterFromRuntimeZeroPoint
 from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
+from brevitas.inject import ExtendedInjector
 from brevitas.inject import this
 from brevitas.inject import value
+from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
 from brevitas.quant.scaled_int import Int8ActPerTensorFloat
 from brevitas.quant.scaled_int import Int8ActPerTensorFloatMSE
 from brevitas.quant.scaled_int import Int8WeightPerChannelFloat
@@ -25,11 +27,7 @@
 from .quant_blocks import *
 
 
-class IntWeightSymmetricGroupQuant(Int8WeightPerChannelFloat):
-    """
-    Block / group / vector signed symmetric weight quantizer with float scales.
-    We inherit from a per-channel quantizer to re-use some underlying machinery.
-    """
+class WeightSymmetricGroupQuantMixin(ExtendedInjector):
 
     @value
     def expanded_scaling_shape(module, block_size):
@@ -69,6 +67,23 @@ def reshaped_scaling_shape(module):
     block_size = None
 
 
+class IntWeightSymmetricGroupQuant(WeightSymmetricGroupQuantMixin, Int8WeightPerChannelFloat):
+    """
+    Block / group / vector signed symmetric int weight quantizer with float scales.
+    We inherit from a per-channel quantizer to re-use some underlying machinery.
+    """
+    pass
+
+
+class Fp8e4m3WeightSymmetricGroupQuant(WeightSymmetricGroupQuantMixin,
+                                       Fp8e4m3WeightPerChannelFloat):
+    """
+    Block / group / vector signed symmetric e4m3 weight quantizer with float scales.
+    We inherit from a per-channel quantizer to re-use some underlying machinery.
+    """
+    pass
+
+
 class ShiftedUintWeightAsymmetricGroupQuant(IntWeightSymmetricGroupQuant):
     """
     Block / group / vector unsigned asymmetric weight quantizer with float scales and zero-points.
@@ -125,7 +140,7 @@ class Int8ActDynamicPerRowFloat(Int8ActPerRowFloat):
 
 class Int8ActDynamicPerGroupFloat(Int8ActPerRowFloat):
     """
-    Symmetric quantizer with per row dynamic scale.
+    Symmetric quantizer with per group dynamic scale.
     """
     scaling_impl = RuntimeDynamicGroupStatsScaling
     keepdim = True
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index ffed3146b..5f640dcda 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -4,6 +4,7 @@
 """
 
 import argparse
+import re
 
 import numpy as np
 import torch
@@ -24,7 +25,21 @@
 from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32
 from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
 
+
+class CustomValidator(object):
+
+    def __init__(self, pattern):
+        self._pattern = re.compile(pattern)
+
+    def __call__(self, value):
+        # fullmatch so that values like 'int8' or 'e4m3x' are rejected
+        if not self._pattern.fullmatch(value):
+            raise argparse.ArgumentTypeError(
+                "Argument has to match '{}'".format(self._pattern.pattern))
+        return value
+
+
 parser = argparse.ArgumentParser()
+quant_format_validator = CustomValidator(r"int|e[1-8]m[1-8]")
 parser.add_argument(
     '--model',
     type=str,
@@ -46,8 +61,8 @@
 parser.add_argument(
     '--weight-scale-precision',
     type=str,
-    default='float',
-    choices=['float', 'po2'],
+    default='float_scale',
+    choices=['float_scale', 'po2_scale'],
     help='Whether scale is a float value or a po2. Default: float_scale.')
 parser.add_argument(
     '--weight-quant-type',
     type=str,
@@ -55,6 +70,12 @@
     default='asym',
     choices=['sym', 'asym'],
     help='Weight quantization type. Default: asym.')
+parser.add_argument(
+    '--weight-quant-format',
+    type=quant_format_validator,
+    default='int',
+    help='Weight quantization format. Either int or eXmY, with X+Y==weight_bit_width-1. Default: int.'
+)
 parser.add_argument(
     '--weight-quant-granularity',
     type=str,
@@ -73,6 +94,11 @@
     type=int,
     default=None,
     help='Input bit width. Default: None (disables input quantization).')
+parser.add_argument(
+    '--input-quant-format',
+    type=quant_format_validator,
+    default='int',
+    help='Input quantization format. Either int or eXmY, with X+Y==input_bit_width-1. Default: int.')
 parser.add_argument(
     '--input-param-method',
     type=str,
@@ -84,14 +110,14 @@
 parser.add_argument(
     '--input-scale-precision',
     type=str,
-    default='float',
-    choices=['float', 'po2'],
+    default='float_scale',
+    choices=['float_scale', 'po2_scale'],
     help='Whether input scale is a float value or a po2. Default: float_scale.')
 parser.add_argument(
     '--input-scale-type',
     type=str,
     default='static',
-    choices=['static', 'dynamic'],
+    choices=['static', 'dynamic', 'no_scale'],
     help='Whether input scale is static, dynamic, or unused (no_scale). Default: static.')
 parser.add_argument(
     '--input-quant-type',
@@ -171,6 +197,8 @@ def model_export(model, ref_input, args):
 
 def validate(args):
     if not args.no_quantize:
+        if args.export_target is not None:
+            assert args.input_quant_format == 'int', "Only integer quantization supported for export currently."
         if args.export_target is not None and args.input_bit_width is not None:
             assert args.input_scale_type == 'static', "Only static scale supported for export currently."
         if args.export_target == 'sharded_torchmlir_group_weight':
@@ -268,6 +296,7 @@ def main():
         quantize_model(
             layers_to_quantize,
             dtype=dtype,
+            weight_quant_format=args.weight_quant_format,
             weight_quant_type=args.weight_quant_type,
             weight_bit_width=args.weight_bit_width,
             weight_param_method=args.weight_param_method,
@@ -277,6 +306,7 @@ def main():
             quantize_weight_zero_point=args.quantize_weight_zero_point,
             input_bit_width=args.input_bit_width,
             input_quant_type=args.input_quant_type,
+            input_quant_format=args.input_quant_format,
             input_param_method=args.input_param_method,
             input_scale_precision=args.input_scale_precision,
             input_scale_type=args.input_scale_type,
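
A note on the eXmY convention used by the new --weight-quant-format and --input-quant-format flags: the format string encodes X exponent bits and Y mantissa bits, and one further bit is implied for the sign, which is where the X+Y==bit_width-1 constraint mentioned in the help text comes from (e.g. e4m3 is 4+3+1 = 8 bits). The patch only slices the exponent and mantissa digits out of the string; the sketch below (a hypothetical standalone helper, not part of the patch) shows the same decomposition and additionally checks the bit-width constraint explicitly.

    import re

    def parse_quant_format(fmt, bit_width):
        # 'int' keeps the integer quantizer path; 'eXmY' selects the minifloat path.
        match = re.fullmatch(r'e([1-8])m([1-8])', fmt)
        if match is None:
            assert fmt == 'int', "expected 'int' or an eXmY format string"
            return 'int', {}
        exponent_bits, mantissa_bits = int(match.group(1)), int(match.group(2))
        # One bit is reserved for the sign, e.g. e4m3 -> 4 + 3 + 1 = 8 bits.
        assert exponent_bits + mantissa_bits == bit_width - 1, "eXmY needs X+Y == bit_width-1"
        return 'float', {
            'exponent_bit_width': exponent_bits, 'mantissa_bit_width': mantissa_bits}

    # parse_quant_format('e4m3', 8) -> ('float', {'exponent_bit_width': 4, 'mantissa_bit_width': 3})
    # parse_quant_format('int', 8)  -> ('int', {})

With the flags above, an FP8 weight-quantization run of the example would then look something like the following (model name illustrative, and assuming the example's usual --weight-bit-width flag):

    python main.py --model facebook/opt-125m --weight-bit-width 8 --weight-quant-format e4m3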