diff --git a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py
index 58af032ef..b6ddffa56 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py
@@ -5,7 +5,6 @@
 from itertools import product
 import os
 import random
-from time import sleep
 from types import SimpleNamespace
 
 import numpy as np
@@ -38,6 +37,18 @@
 config.IGNORE_MISSING_KEYS = True
 
+
+class hashabledict(dict):
+
+    def __hash__(self):
+        return hash(tuple(sorted(self.items())))
+
+
+def unique(sequence):
+    seen = set()
+    return [x for x in sequence if not (x in seen or seen.add(x))]
+
+
 # Torchvision models with top1 accuracy
 TORCHVISION_TOP1_MAP = {
     'resnet18': 69.758,
@@ -68,8 +79,14 @@
 }
 
 OPTIONS_DEFAULT = {
+    'model_name': list(TORCHVISION_TOP1_MAP.keys()),
+    'quant_format': ['int'],  # Quantization type (INT vs Float)
     'target_backend': ['fx'],  # Target backend
-    'scale_factor_type': ['float'],  # Scale factor type
+    'scale_factor_type': ['float_scale'],  # Scale factor type
+    'weight_mantissa_bit_width': [4],
+    'weight_exponent_bit_width': [3],
+    'act_mantissa_bit_width': [4],
+    'act_exponent_bit_width': [3],
     'weight_bit_width': [8],  # Weight Bit Width
     'act_bit_width': [8],  # Act bit width
     'bias_bit_width': [32],  # Bias Bit-Width for Po2 scale
@@ -84,7 +101,7 @@
     'learned_round': [False],  # Enable/Disable Learned Round
     'gptq': [True],  # Enable/Disable GPTQ
     'gpfq': [False],  # Enable/Disable GPFQ
-    'gpfq_p': [0.25],  # GPFQ P
+    'gpfq_p': [0.75],  # GPFQ P
-    'gptq_act_order': [False],  # Use act_order euristics for GPTQ
+    'gptq_act_order': [False],  # Use act_order heuristics for GPTQ
     'act_quant_percentile': [99.999],  # Activation Quantization Percentile
 }
@@ -106,8 +123,9 @@
 parser.add_argument(
     '--batch-size-validation', default=256, type=int, help='Minibatch size for validation')
 parser.add_argument('--calibration-samples', default=1000, type=int, help='Calibration size')
-parser.add_argument(
-    '--options-to-exclude', choices=OPTIONS.keys(), nargs="+", default=[], help='Calibration size')
+for option_name, option_value in OPTIONS_DEFAULT.items():
+    parser.add_argument(
+        f'--{option_name}', default=option_value, nargs="+", type=type(option_value[0]))
 
 
 def main():
@@ -116,11 +134,9 @@ def main():
     np.random.seed(SEED)
     torch.manual_seed(SEED)
 
-    for option in args.options_to_exclude:
-        OPTIONS[option] = OPTIONS_DEFAULT[option]
-
     args.gpu = get_gpu_index(args.idx)
     print("Iter {}, GPU {}".format(args.idx, args.gpu))
+
     try:
         ptq_torchvision_models(args)
     except Exception as E:
@@ -129,43 +145,25 @@
 
 def ptq_torchvision_models(args):
     # Generate all possible combinations, including invalid ones
-    # Split stats and mse due to the act_quant_percentile value
-    if 'stats' in OPTIONS['act_param_method']:
-        percentile_options = OPTIONS.copy()
-        percentile_options['act_param_method'] = ['stats']
-    else:
-        percentile_options = None
+    options = {k: getattr(args, k) for k, _ in OPTIONS_DEFAULT.items()}
 
-    if 'mse' in OPTIONS['act_param_method']:
-        mse_options = OPTIONS.copy()
-        mse_options['act_param_method'] = ['mse']
-        mse_options['act_quant_percentile'] = [None]
-    else:
-        mse_options = None
-
-    # Combine MSE and Percentile combinations, if they are defined
-    combinations = []
-    if mse_options is not None:
-        combinations += list(product(*mse_options.values()))
-    if percentile_options is not None:
-        combinations += list(product(*percentile_options.values()))
-    # Combine the two sets of combinations
-    # Generate Namespace for each configuration
-    configs = [
-        SimpleNamespace(**{k: v
-                           for k, v in zip(OPTIONS.keys(), combination)})
-        for combination in combinations]
-    # Define which configurations are not valid
-    configs = list(map(validate_config, configs))
-    # Drop invalid configurations
-    configs = list(config for config in configs if config.is_valid)
+    combinations = list(product(*options.values()))
+
+    configs = []
+    for combination in combinations:
+        config_namespace = SimpleNamespace(
+            **{k: v for k, v in zip(OPTIONS_DEFAULT.keys(), combination)})
+        config_namespace = validate_config(config_namespace)
+        if config_namespace.is_valid:
+            configs.append(hashabledict(**config_namespace.__dict__))
+
+    configs = unique(configs)
 
-    if args.idx > len(configs):
+    if args.idx >= len(configs):
         return
 
-    config_namespace = configs[args.idx]
-    print(config_namespace)
+    config_namespace = SimpleNamespace(**configs[args.idx])
 
     fp_accuracy = TORCHVISION_TOP1_MAP[config_namespace.model_name]
 
     # Get model-specific configurations about input shapes and normalization
@@ -219,6 +217,7 @@ def ptq_torchvision_models(args):
     # Define the quantized model
     quant_model = quantize_model(
         model,
+        quant_format=config_namespace.quant_format,
         backend=config_namespace.target_backend,
         act_bit_width=config_namespace.act_bit_width,
         weight_bit_width=config_namespace.weight_bit_width,
@@ -295,7 +294,7 @@
     # Flexml supports only per-tensor scale factors, power of two scale factors
     if config_namespace.target_backend == 'flexml' and (
             config_namespace.weight_quant_granularity == 'per_channel' or
-            config_namespace.scale_factor_type == 'float32'):
+            config_namespace.scale_factor_type == 'float_scale'):
         is_valid = False
     # Merge bias can be enabled only when graph equalization is enabled
     if config_namespace.graph_eq_iterations == 0 and config_namespace.graph_eq_merge_bias:
@@ -308,15 +307,27 @@
     if not config_namespace.gptq and config_namespace.gptq_act_order:
         is_valid = False
 
-    # If GPFQ is disabled, we execute only one configuration for p==0.25
-    if not config_namespace.gpfq and config_namespace.gpfq_p == 0.75:
-        is_valid = False
-
     if config_namespace.act_equalization == 'layerwise' and config_namespace.target_backend == 'fx':
         is_valid = False
     if config_namespace.act_bit_width < config_namespace.weight_bit_width:
         is_valid = False
 
+    if config_namespace.act_param_method == 'mse':
+        config_namespace.act_quant_percentile = None
+
+    if not config_namespace.gpfq:
+        config_namespace.gpfq_p = None
+
+    if config_namespace.quant_format == 'int':
+        config_namespace.weight_mantissa_bit_width = None
+        config_namespace.weight_exponent_bit_width = None
+        config_namespace.act_mantissa_bit_width = None
+        config_namespace.act_exponent_bit_width = None
+
+    if config_namespace.quant_format == 'float':
+        config_namespace.act_quant_type = 'sym'
+        config_namespace.weight_quant_type = 'sym'
+
     config_namespace.is_valid = is_valid
     return config_namespace
 
diff --git a/src/brevitas_examples/imagenet_classification/ptq/benchmark/single_command.sh b/src/brevitas_examples/imagenet_classification/ptq/benchmark/single_command.sh
index bcc1fec09..f662008a8 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/benchmark/single_command.sh
+++ b/src/brevitas_examples/imagenet_classification/ptq/benchmark/single_command.sh
@@ -1 +1,4 @@
-python ptq_benchmark_torchvision.py $1 --calibration-dir /scratch/datasets/imagenet_symlink/calibration --validation-dir /scratch/datasets/imagenet_symlink/val --options-to-exclude graph_eq_merge_bias graph_eq_iterations
+python ptq_benchmark_torchvision.py $1 --calibration-dir /scratch/datasets/imagenet_symlink/calibration --validation-dir /scratch/datasets/imagenet_symlink/val \
+--graph_eq_iterations 50 \
+--act_param_method stats mse \
+--act_quant_percentile 99.9 99.99
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
index 42e976134..1fc4508ab 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -19,6 +19,10 @@
 from brevitas.graph.quantize import quantize
 from brevitas.graph.target.flexml import quantize_flexml
 import brevitas.nn as qnn
+from brevitas.quant.experimental.float import Fp8e4m3Act
+from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
+from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
+from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat
 from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint
 from brevitas.quant.fixed_point import Int8ActPerTensorFixedPointMSE
 from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint
@@ -49,45 +53,61 @@
 BIAS_BIT_WIDTH_MAP = {32: Int32Bias, 16: Int16Bias, None: None}
 
 WEIGHT_QUANT_MAP = {
+    'int': {
+        'float_scale': {
+            'stats': {
+                'per_tensor': {
+                    'sym': Int8WeightPerTensorFloat, 'asym': ShiftedUint8WeightPerTensorFloat},
+                'per_channel': {
+                    'sym': Int8WeightPerChannelFloat, 'asym': ShiftedUint8WeightPerChannelFloat}},
+            'mse': {
+                'per_tensor': {
+                    'sym': Int8WeightPerTensorFloatMSE,
+                    'asym': ShiftedUint8WeightPerTensorFloatMSE},
+                'per_channel': {
+                    'sym': Int8WeightPerChannelFloatMSE,
+                    'asym': ShiftedUint8WeightPerChannelFloatMSE},},},
+        'po2_scale': {
+            'stats': {
+                'per_tensor': {
+                    'sym': Int8WeightPerTensorFixedPoint},
+                'per_channel': {
+                    'sym': Int8WeightPerChannelFixedPoint},},
+            'mse': {
+                'per_tensor': {
+                    'sym': Int8WeightPerTensorFixedPointMSE},
+                'per_channel': {
+                    'sym': Int8WeightPerChannelFixedPointMSE}},}},
     'float': {
-        'stats': {
-            'per_tensor': {
-                'sym': Int8WeightPerTensorFloat, 'asym': ShiftedUint8WeightPerTensorFloat},
-            'per_channel': {
-                'sym': Int8WeightPerChannelFloat, 'asym': ShiftedUint8WeightPerChannelFloat}},
-        'mse': {
-            'per_tensor': {
-                'sym': Int8WeightPerTensorFloatMSE, 'asym': ShiftedUint8WeightPerTensorFloatMSE},
-            'per_channel': {
-                'sym': Int8WeightPerChannelFloatMSE, 'asym': ShiftedUint8WeightPerChannelFloatMSE},
-        },},
-    'po2': {
-        'stats': {
-            'per_tensor': {
-                'sym': Int8WeightPerTensorFixedPoint},
-            'per_channel': {
-                'sym': Int8WeightPerChannelFixedPoint},},
-        'mse': {
-            'per_tensor': {
-                'sym': Int8WeightPerTensorFixedPointMSE},
-            'per_channel': {
-                'sym': Int8WeightPerChannelFixedPointMSE}},}}
+        'float_scale': {
+            'stats': {
+                'per_tensor': {
+                    'sym': Fp8e4m3WeightPerTensorFloat},
+                'per_channel': {
+                    'sym': Fp8e4m3WeightPerChannelFloat}}}}}
 
 INPUT_QUANT_MAP = {
+    'int': {
+        'float_scale': {
+            'stats': {
+                'per_tensor': {
+                    'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat}},
+            'mse': {
+                'per_tensor': {
+                    'sym': Int8ActPerTensorFloatMSE, 'asym': ShiftedUint8ActPerTensorFloatMSE}}},
+        'po2_scale': {
+            'stats': {
+                'per_tensor': {
+                    'sym': Int8ActPerTensorFixedPoint, 'asym': ShiftedUint8ActPerTensorFixedPoint},
+            },
+            'mse': {
+                'per_tensor': {
+                    'sym': Int8ActPerTensorFixedPointMSE}},}},
     'float': {
-        'stats': {
-            'per_tensor': {
-                'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat}},
-        'mse': {
-            'per_tensor': {
-                'sym': Int8ActPerTensorFloatMSE, 'asym': ShiftedUint8ActPerTensorFloatMSE}}},
-    'po2': {
-        'stats': {
-            'per_tensor': {
-                'sym': Int8ActPerTensorFixedPoint, 'asym': ShiftedUint8ActPerTensorFixedPoint},},
-        'mse': {
-            'per_tensor': {
-                'sym': Int8ActPerTensorFixedPointMSE}},}}
+        'float_scale': {
+            'stats': {
+                'per_tensor': {
+                    'sym': Fp8e4m3ActPerTensorFloat},}}},}
 
 
 def quantize_model(
@@ -100,7 +120,14 @@
         act_quant_percentile,
         act_quant_type,
         scale_factor_type,
+        quant_format,
         layerwise_first_last_bit_width=8,
+        layerwise_first_last_mantissa_bit_width=4,
+        layerwise_first_last_exponent_bit_width=3,
+        weight_mantissa_bit_width=4,
+        weight_exponent_bit_width=3,
+        act_mantissa_bit_width=4,
+        act_exponent_bit_width=3,
         weight_narrow_range=False,
         weight_param_method='stats',
         act_param_method='stats',
@@ -112,20 +139,34 @@
     weight_scale_type = scale_factor_type
     act_scale_type = scale_factor_type
 
-    def bit_width_fn(module, other_bit_width):
+    weight_quant_format = quant_format
+    act_quant_format = quant_format
+    weight_quant_granularity = weight_quant_granularity
+
+    def bit_width_fn(module, first_last_bit_width, other_bit_width):
 
         if isinstance(module, torch.nn.Conv2d) and module.in_channels == 3:
-            return layerwise_first_last_bit_width
+            return first_last_bit_width
         elif isinstance(module, torch.nn.Linear) and module.out_features == 1000:
-            return layerwise_first_last_bit_width
+            return first_last_bit_width
         else:
             return other_bit_width
 
     weight_bit_width_or_lambda = weight_bit_width if backend != 'layerwise' else lambda module: bit_width_fn(
-        module, weight_bit_width)
+        module, layerwise_first_last_bit_width, weight_bit_width)
     act_bit_width_or_lambda = act_bit_width if backend != 'layerwise' else lambda module: bit_width_fn(
-        module, act_bit_width)
+        module, layerwise_first_last_bit_width, act_bit_width)
+
+    weight_mantissa_bit_width_or_lambda = weight_mantissa_bit_width if backend != 'layerwise' else lambda module: bit_width_fn(
+        module, layerwise_first_last_mantissa_bit_width, weight_mantissa_bit_width)
+    weight_exponent_bit_width_or_lambda = weight_exponent_bit_width if backend != 'layerwise' else lambda module: bit_width_fn(
+        module, layerwise_first_last_exponent_bit_width, weight_exponent_bit_width)
+
+    act_mantissa_bit_width_or_lambda = act_mantissa_bit_width if backend != 'layerwise' else lambda module: bit_width_fn(
+        module, layerwise_first_last_mantissa_bit_width, act_mantissa_bit_width)
+    act_exponent_bit_width_or_lambda = act_exponent_bit_width if backend != 'layerwise' else lambda module: bit_width_fn(
+        module, layerwise_first_last_exponent_bit_width, act_exponent_bit_width)
+
     quant_layer_map, quant_layerwise_layer_map, quant_act_map, quant_identity_map = create_quant_maps(dtype=dtype,
         bias_bit_width=bias_bit_width,
         weight_bit_width=weight_bit_width_or_lambda,
@@ -134,6 +175,12 @@ def bit_width_fn(module, other_bit_width):
         weight_quant_type=weight_quant_type,
         weight_quant_granularity=weight_quant_granularity,
         weight_narrow_range=weight_narrow_range,
+        weight_quant_format=weight_quant_format,
+        weight_mantissa_bit_width=weight_mantissa_bit_width_or_lambda,
+        weight_exponent_bit_width=weight_exponent_bit_width_or_lambda,
+        act_mantissa_bit_width=act_mantissa_bit_width_or_lambda,
+        act_exponent_bit_width=act_exponent_bit_width_or_lambda,
+        act_quant_format=act_quant_format,
         act_bit_width=act_bit_width_or_lambda,
         act_scale_type=act_scale_type,
         act_param_method=act_param_method,
@@ -164,6 +211,12 @@ def create_quant_maps(
         weight_quant_type,
         weight_quant_granularity,
         weight_narrow_range,
+        weight_quant_format,
+        weight_mantissa_bit_width,
+        weight_exponent_bit_width,
+        act_mantissa_bit_width,
+        act_exponent_bit_width,
+        act_quant_format,
         act_bit_width=None,
         act_scale_type=None,
         act_param_method=None,
@@ -177,19 +230,38 @@
     def kwargs_prefix(prefix, weight_kwargs):
         return {prefix + k: v for k, v in weight_kwargs.items()}
 
+    if weight_quant_format == 'float':
+        weight_float_format = {
+            'exponent_bit_width': weight_exponent_bit_width,
+            'mantissa_bit_width': weight_mantissa_bit_width}
+    else:
+        weight_float_format = {}
+
+    if act_quant_format == 'float':
+        act_float_format = {
+            'exponent_bit_width': act_exponent_bit_width,
+            'mantissa_bit_width': act_mantissa_bit_width}
+    else:
+        act_float_format = {}
+
     # Retrieve base input, weight, and bias quantizers
     bias_quant = BIAS_BIT_WIDTH_MAP[bias_bit_width]
-    weight_quant = WEIGHT_QUANT_MAP[weight_scale_type][weight_param_method][
+    weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_type][weight_param_method][
         weight_quant_granularity][weight_quant_type]
+    weight_quant = weight_quant.let(**weight_float_format)
 
     if act_bit_width is not None:
-        act_quant = INPUT_QUANT_MAP[act_scale_type][act_param_method][act_quant_granularity][
-            act_quant_type]
+        act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_type][act_param_method][
+            act_quant_granularity][act_quant_type]
         # Some activations in MHA should always be symmetric
-        sym_act_quant = INPUT_QUANT_MAP[act_scale_type][act_param_method][act_quant_granularity][
-            'sym']
+        sym_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_type][act_param_method][
+            act_quant_granularity]['sym']
         # Linear layers with 2d input should always be per tensor
-        per_tensor_act_quant = INPUT_QUANT_MAP[act_scale_type][act_param_method]['per_tensor'][
-            act_quant_type]
+        per_tensor_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_type][act_param_method][
+            'per_tensor'][act_quant_type]
+        act_quant = act_quant.let(**act_float_format)
+        sym_act_quant = sym_act_quant.let(**act_float_format)
+        per_tensor_act_quant = per_tensor_act_quant.let(**act_float_format)
+
     else:
         act_quant = None
         sym_act_quant = None
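
Editor's note, not part of the patch: the configuration-deduplication scheme introduced above is subtle enough to deserve a standalone illustration. Instead of rejecting redundant grid points, `validate_config` now normalizes don't-care fields (e.g. `gpfq_p` when GPFQ is disabled, or the mantissa/exponent widths when the format is integer), so combinations that differ only in those fields become equal dicts; `hashabledict` makes them hashable and `unique` keeps the first occurrence. A minimal, runnable sketch using a hypothetical two-option grid:

    from itertools import product
    from types import SimpleNamespace


    class hashabledict(dict):

        def __hash__(self):
            return hash(tuple(sorted(self.items())))


    def unique(sequence):
        # Order-preserving deduplication
        seen = set()
        return [x for x in sequence if not (x in seen or seen.add(x))]


    # Hypothetical grid: GPFQ disabled, two values of its 'p' parameter
    options = {'gpfq': [False], 'gpfq_p': [0.25, 0.75]}

    configs = []
    for combination in product(*options.values()):
        ns = SimpleNamespace(**dict(zip(options.keys(), combination)))
        if not ns.gpfq:
            ns.gpfq_p = None  # normalize the don't-care field
        configs.append(hashabledict(**ns.__dict__))

    # Both combinations collapse into {'gpfq': False, 'gpfq_p': None}
    assert len(unique(configs)) == 1

The same normalization is what makes the fp8 options safe to sweep: when `quant_format` is `'int'`, the mantissa/exponent widths are cleared before hashing, so the integer configurations are not multiplied by the float-only axes.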