From 6c69a9c32707e672c54b48c24d56a2b17fc2e4e3 Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Thu, 21 Dec 2023 15:05:37 +0000 Subject: [PATCH] Fix (ptq/evaluate): add support for GPFA2Q for evaluate and benchmark --- .../benchmark/ptq_benchmark_torchvision.py | 43 +++++++++++++++---- .../ptq/benchmark/single_command.sh | 4 +- .../imagenet_classification/ptq/ptq_common.py | 9 +++- .../ptq/ptq_evaluate.py | 40 +++++++++++------ 4 files changed, 71 insertions(+), 25 deletions(-) diff --git a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py index 75bbd1c1a..f5fe652cb 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py +++ b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py @@ -81,6 +81,8 @@ def unique(sequence): 'scale_factor_type': ['float_scale'], # Scale factor type 'weight_mantissa_bit_width': [4], 'weight_exponent_bit_width': [3], + 'weight_narrow_range': [False], + 'layerwise_first_last_bit_width': [8], # Input and weights bit width for first and last layer 'act_mantissa_bit_width': [4], 'act_exponent_bit_width': [3], 'weight_bit_width': [8], # Weight Bit Width @@ -95,10 +97,12 @@ def unique(sequence): 'graph_eq_merge_bias': [True], # Merge bias for Graph Equalization 'act_equalization': ['layerwise'], # Perform Activation Equalization (Smoothquant) 'learned_round': [False], # Enable/Disable Learned Round - 'gptq': [True], # Enable/Disable GPTQ + 'gptq': [False], # Enable/Disable GPTQ 'gpfq': [False], # Enable/Disable GPFQ - 'gpfq_p': [0.75], # GPFQ P - 'gptq_act_order': [False], # Use act_order euristics for GPTQ + 'gpfa2q': [False], # Enable/Disable GPFA2Q + 'gpfq_p': [1.0], # GPFQ P + 'gpxq_act_order': [False], # Use act_order heuristics for GPxQ + 'accumulator_bit_width': [16], # Accumulator bit width, only in combination with GPFA2Q 
'act_quant_percentile': [99.999], # Activation Quantization Percentile 'uint_sym_act_for_unsigned_values': [True], # Whether to use unsigned act quant when possible } @@ -221,6 +225,8 @@ def ptq_torchvision_models(args): quant_format=config_namespace.quant_format, backend=config_namespace.target_backend, act_bit_width=config_namespace.act_bit_width, + layerwise_first_last_bit_width=config_namespace.layerwise_first_last_bit_width, + weight_narrow_range=config_namespace.weight_narrow_range, weight_mantissa_bit_width=config_namespace.weight_mantissa_bit_width, weight_exponent_bit_width=config_namespace.weight_exponent_bit_width, act_mantissa_bit_width=config_namespace.act_mantissa_bit_width, @@ -247,11 +253,25 @@ def ptq_torchvision_models(args): if config_namespace.gpfq: print("Performing GPFQ:") - apply_gpfq(calib_loader, quant_model, p=config_namespace.gpfq_p) + apply_gpfq( + calib_loader, + quant_model, + p=config_namespace.gpfq_p, + act_order=config_namespace.gpxq_act_order) + + if config_namespace.gpfa2q: + print("Performing GPFA2Q:") + apply_gpfq( + calib_loader, + quant_model, + p=config_namespace.gpfq_p, + act_order=config_namespace.gpxq_act_order, + use_gpfa2q=config_namespace.gpfa2q, + accumulator_bit_width=config_namespace.accumulator_bit_width) if config_namespace.gptq: print("Performing gptq") - apply_gptq(calib_loader, quant_model, config_namespace.gptq_act_order) + apply_gptq(calib_loader, quant_model, config_namespace.gpxq_act_order) if config_namespace.learned_round: print("Applying Learned Round:") @@ -309,8 +329,10 @@ def validate_config(config_namespace): if (config_namespace.target_backend == 'fx' or config_namespace.target_backend == 'layerwise') and config_namespace.bias_bit_width == 16: is_valid = False - # If GPTQ is disabled, we do not care about the act_order heuristic - if not config_namespace.gptq and config_namespace.gptq_act_order: + # Only one of GPTQ, GPFQ, or GPFA2Q can be enabled, or none + multiple_gpxqs = float(config_namespace.gpfq) + 
float(config_namespace.gptq) + float( + config_namespace.gpfa2q) + if multiple_gpxqs > 1: is_valid = False if config_namespace.act_equalization == 'layerwise' and config_namespace.target_backend == 'fx': @@ -320,9 +342,12 @@ def validate_config(config_namespace): if config_namespace.act_param_method == 'mse': config_namespace.act_quant_percentile = None - - if not config_namespace.gpfq: + # gpfq_p is needed for GPFQ and GPFA2Q + if not config_namespace.gpfq and not config_namespace.gpfa2q: config_namespace.gpfq_p = None + # accumulator bit width is not needed when not GPFA2Q + if not config_namespace.gpfa2q: + config_namespace.accumulator_bit_width = None if config_namespace.quant_format == 'int': config_namespace.weight_mantissa_bit_width = None diff --git a/src/brevitas_examples/imagenet_classification/ptq/benchmark/single_command.sh b/src/brevitas_examples/imagenet_classification/ptq/benchmark/single_command.sh index c70912fa0..a938999ac 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/benchmark/single_command.sh +++ b/src/brevitas_examples/imagenet_classification/ptq/benchmark/single_command.sh @@ -18,8 +18,10 @@ python ptq_benchmark_torchvision.py $1 --calibration-dir /scratch/datasets/image --act_equalization layerwise \ --learned_round False \ --gptq False \ ---gptq_act_order False \ +--gpxq_act_order False \ --gpfq False \ --gpfq_p None \ +--gpfa2q False \ +--accumulator_bit_width None \ --uint_sym_act_for_unsigned_values False \ --act_quant_percentile None \ diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 41e19c07a..0b626e485 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -472,12 +472,17 @@ def apply_gptq(calib_loader, model, act_order=False): gptq.update() -def apply_gpfq(calib_loader, model, act_order, p=0.25): +def apply_gpfq(calib_loader, 
model, act_order, p=1.0, use_gpfa2q=False, accumulator_bit_width=None): model.eval() dtype = next(model.parameters()).dtype device = next(model.parameters()).device with torch.no_grad(): - with gpfq_mode(model, p=p, use_quant_activations=True, act_order=act_order) as gpfq: + with gpfq_mode(model, + p=p, + use_quant_activations=True, + act_order=act_order, + use_gpfa2q=use_gpfa2q, + accumulator_bit_width=accumulator_bit_width) as gpfq: gpfq_model = gpfq.model for i in tqdm(range(gpfq.num_layers)): for i, (images, target) in enumerate(calib_loader): diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 01148911e..454bf2488 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -168,10 +168,9 @@ add_bool_arg( parser, 'weight-narrow-range', - default=True, - help='Narrow range for weight quantization (default: enabled)') -parser.add_argument( - '--gpfq-p', default=1.0, type=float, help='P parameter for GPFQ (default: 0.25)') + default=False, + help='Narrow range for weight quantization (default: disabled)') +parser.add_argument('--gpfq-p', default=1.0, type=float, help='P parameter for GPFQ (default: 1.0)') parser.add_argument( '--quant-format', default='int', @@ -211,12 +210,16 @@ default=3, type=int, help='Exponent bit width used with float quantization for activations (default: 3)') +parser.add_argument( + '--accumulator-bit-width', + default=None, + type=int, + help='Accumulator Bit Width for GPFA2Q (default: None)') add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)') add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)') +add_bool_arg(parser, 'gpfa2q', default=False, help='GPFA2Q (default: disabled)') add_bool_arg( - parser, 'gptq-act-order', default=False, help='GPTQ Act order heuristic (default: disabled)') -add_bool_arg( 
- parser, 'gpfq-act-order', default=False, help='GPFQ Act order heuristic (default: disabled)') + parser, 'gpxq-act-order', default=False, help='GPxQ Act order heuristic (default: disabled)') add_bool_arg(parser, 'learned-round', default=False, help='Learned round (default: disabled)') add_bool_arg(parser, 'calibrate-bn', default=False, help='Calibrate BN (default: disabled)') @@ -246,8 +249,8 @@ def main(): f"w{args.weight_bit_width}_" f"{'gptq_' if args.gptq else ''}" f"{'gpfq_' if args.gpfq else ''}" - f"{'gptq_act_order_' if args.gptq_act_order else ''}" - f"{'gpfq_act_order_' if args.gpfq_act_order else ''}" + f"{'gpfa2q_' if args.gpfa2q else ''}" + f"{'gpxq_act_order_' if args.gpxq_act_order else ''}" f"{'learned_round_' if args.learned_round else ''}" f"{'weight_narrow_range_' if args.weight_narrow_range else ''}" f"{args.bias_bit_width}bias_" @@ -268,9 +271,10 @@ def main(): f"Weight bit width: {args.weight_bit_width} - " f"GPTQ: {args.gptq} - " f"GPFQ: {args.gpfq} - " + f"GPFA2Q: {args.gpfa2q} - " f"GPFQ P: {args.gpfq_p} - " - f"GPTQ Act Order: {args.gptq_act_order} - " - f"GPFQ Act Order: {args.gpfq_act_order} - " + f"GPxQ Act Order: {args.gpxq_act_order} - " + f"GPFA2Q Accumulator Bit Width: {args.accumulator_bit_width} - " f"Learned Round: {args.learned_round} - " f"Weight narrow range: {args.weight_narrow_range} - " f"Bias bit width: {args.bias_bit_width} - " @@ -367,11 +371,21 @@ def main(): if args.gpfq: print("Performing GPFQ:") - apply_gpfq(calib_loader, quant_model, p=args.gpfq_p, act_order=args.gpfq_act_order) + apply_gpfq(calib_loader, quant_model, p=args.gpfq_p, act_order=args.gpxq_act_order) + + if args.gpfa2q: + print("Performing GPFA2Q:") + apply_gpfq( + calib_loader, + quant_model, + p=args.gpfq_p, + act_order=args.gpxq_act_order, + use_gpfa2q=args.gpfa2q, + accumulator_bit_width=args.accumulator_bit_width) if args.gptq: print("Performing GPTQ:") - apply_gptq(calib_loader, quant_model, act_order=args.gptq_act_order) + 
apply_gptq(calib_loader, quant_model, act_order=args.gpxq_act_order) if args.learned_round: print("Applying Learned Round:")