From 3464ec790c482e5b1b921baa020ce71fb078efc0 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 14 May 2024 16:24:15 +0200
Subject: [PATCH] Feat (examples/ptq): support for dynamic act quant (#935)

---
 src/brevitas/quant_tensor/torch_handler.py    | 14 +++-
 .../common/generative/quant_blocks.py         |  1 +
 .../imagenet_classification/ptq/README.md     |  4 +
 .../imagenet_classification/ptq/ptq_common.py | 78 +++++++++++++------
 .../ptq/ptq_evaluate.py                       | 13 +++-
 5 files changed, 80 insertions(+), 30 deletions(-)

diff --git a/src/brevitas/quant_tensor/torch_handler.py b/src/brevitas/quant_tensor/torch_handler.py
index fcbe35a42..79934864f 100644
--- a/src/brevitas/quant_tensor/torch_handler.py
+++ b/src/brevitas/quant_tensor/torch_handler.py
@@ -3,12 +3,13 @@
 import functools
 import math
+from typing import Callable
 import warnings
 
 import torch
+from torch import Tensor
 import torch.nn.functional as F
 
-import brevitas
 from brevitas.function.ops import max_int
 from brevitas.function.ops_ste import ceil_ste
 from brevitas.utils.torch_utils import compute_channel_view_shape
@@ -358,11 +359,16 @@ def create_quant_tensor(tensor, scale, bit_width, zero_point, signed, training):
         training=training)
 
 
-def quant_output_scale_impl(fn, inp, quant_input_scale, quant_weight_scale):
+def quant_output_scale_impl(
+        fn: Callable, inp: Tensor, quant_input_scale: Tensor, quant_weight_scale: Tensor):
     channel_dim = -1 if fn == F.linear else 1
     output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim)
-    output_scale = quant_weight_scale.view(output_scale_shape)
-    output_scale = output_scale * quant_input_scale.view(output_scale_shape)
+
+    quant_weight_scale = quant_weight_scale.view(output_scale_shape)
+    if len(quant_input_scale.shape) == 0:
+        quant_input_scale = quant_input_scale.view(output_scale_shape)
+
+    output_scale = quant_weight_scale * quant_input_scale
     return output_scale
 
 
diff --git a/src/brevitas_examples/common/generative/quant_blocks.py b/src/brevitas_examples/common/generative/quant_blocks.py
index e403deecf..40516111f 100644
--- a/src/brevitas_examples/common/generative/quant_blocks.py
+++ b/src/brevitas_examples/common/generative/quant_blocks.py
@@ -106,6 +106,7 @@ def forward(self, x) -> Tensor:
         shape = x.shape
         x = self.scaling_stats_input_view_shape_impl(x)
         x = self.stats_impl(x)
+        x = self.dynamic_scaling_broadcastable_fn(x, shape)
         return x
 
 
diff --git a/src/brevitas_examples/imagenet_classification/ptq/README.md b/src/brevitas_examples/imagenet_classification/ptq/README.md
index 19aa4054c..5387014e9 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/README.md
+++ b/src/brevitas_examples/imagenet_classification/ptq/README.md
@@ -84,6 +84,7 @@ usage: ptq_evaluate.py [-h] --calibration-dir CALIBRATION_DIR --validation-dir
                        [--weight-quant-calibration-type {stats,mse}]
                        [--act-equalization {fx,layerwise,None}]
                        [--act-quant-calibration-type {stats,mse}]
+                       [--act-scale-computation-type {static,dynamic}]
                        [--graph-eq-iterations GRAPH_EQ_ITERATIONS]
                        [--learned-round-iters LEARNED_ROUND_ITERS]
                        [--learned-round-lr LEARNED_ROUND_LR]
@@ -184,6 +185,9 @@ options:
   --act-quant-calibration-type {stats,mse}
                         Activation quantization calibration type (default:
                         stats)
+  --act-scale-computation-type {static,dynamic}
+                        Activation quantization scale computation type
+                        (default: static)
   --graph-eq-iterations GRAPH_EQ_ITERATIONS
                         Numbers of iterations for graph equalization (default:
                         20)
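The torch_handler.py change above is what lets dynamic scales flow through the quantized functional handlers: a static per-tensor input scale is a 0-dim tensor that must be viewed to the broadcast shape derived from the input, whereas a dynamic scale already carries a leading batch dimension (e.g. (N, 1) or (N, 1, 1, 1)) and only needs elementwise broadcasting. Below is a minimal sketch of that logic in plain PyTorch; compute_channel_view_shape is approximated inline rather than imported, so this is illustrative and not the Brevitas implementation itself.

import torch
import torch.nn.functional as F


def output_scale_sketch(fn, inp, input_scale, weight_scale):
    channel_dim = -1 if fn == F.linear else 1
    # Approximation of compute_channel_view_shape: all ones, with -1 at the
    # channel dimension so .view() infers the channel count.
    shape = [1] * inp.dim()
    shape[channel_dim] = -1
    weight_scale = weight_scale.view(shape)
    # Only a 0-dim static per-tensor scale is reshaped; a dynamic per-batch
    # scale already has a broadcastable shape such as (N, 1) or (N, 1, 1, 1).
    if input_scale.dim() == 0:
        input_scale = input_scale.view(shape)
    return weight_scale * input_scale


inp = torch.randn(8, 16)
w_scale = torch.rand(32)  # one scale per output channel
print(output_scale_sketch(F.linear, inp, torch.tensor(0.05), w_scale).shape)  # (1, 32)
print(output_scale_sketch(F.linear, inp, torch.rand(8, 1), w_scale).shape)    # (8, 32)

Viewing the dynamic (8, 1) scale to the (1, -1) broadcast shape, as the old code did unconditionally, would silently turn it into a (1, 8) per-channel scale; hence the 0-dim check.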
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
index 95153bd0e..2ac2af250 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -8,6 +8,7 @@
 import torch.backends.cudnn as cudnn
 from tqdm import tqdm
 
+from brevitas.core.function_wrapper.shape import OverBatchOverTensorView
 from brevitas.core.scaling.standalone import ParameterFromStatsFromParameterScaling
 from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
 from brevitas.graph.calibrate import bias_correction_mode
@@ -49,10 +50,28 @@
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatMSE
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE
+from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat
+from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat
 from brevitas_examples.imagenet_classification.ptq.learned_round_utils import learned_round_iterator
 from brevitas_examples.imagenet_classification.ptq.learned_round_utils import save_inp_out_data
 from brevitas_examples.imagenet_classification.ptq.learned_round_utils import split_layers
 
+
+# Every element of the batch will have its own scale factor and zero point
+class CNNShiftedUint8DynamicActPerTensorFloat(ShiftedUint8DynamicActPerTensorFloat):
+    scaling_stats_input_view_shape_impl = OverBatchOverTensorView
+    scaling_stats_permute_dims = None
+    stats_reduce_dim = 1
+    dynamic_scaling_broadcastable_fn = lambda x, shape: x.view(shape[0], *[1 for _ in range(len(shape[1:]))])
+
+
+class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
+    scaling_stats_input_view_shape_impl = OverBatchOverTensorView
+    scaling_stats_permute_dims = None
+    stats_reduce_dim = 1
+    dynamic_scaling_broadcastable_fn = lambda x, shape: x.view(shape[0], *[1 for _ in range(len(shape[1:]))])
+
+
 QUANTIZE_MAP = {'layerwise': layerwise_quantize, 'fx': quantize, 'flexml': quantize_flexml}
 
 BIAS_BIT_WIDTH_MAP = {32: Int32Bias, 16: Int16Bias, None: None}
@@ -98,21 +117,29 @@ INPUT_QUANT_MAP = {
     'int': {
-        'float_scale': {
-            'stats': {
-                'per_tensor': {
-                    'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat}},
-            'mse': {
-                'per_tensor': {
-                    'sym': Int8ActPerTensorFloatMSE, 'asym': ShiftedUint8ActPerTensorFloatMSE}}},
-        'po2_scale': {
-            'stats': {
-                'per_tensor': {
-                    'sym': Int8ActPerTensorFixedPoint, 'asym': ShiftedUint8ActPerTensorFixedPoint},
-            },
-            'mse': {
-                'per_tensor': {
-                    'sym': Int8ActPerTensorFixedPointMSE}},}},
+        'static': {
+            'float_scale': {
+                'stats': {
+                    'per_tensor': {
+                        'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat}},
+                'mse': {
+                    'per_tensor': {
+                        'sym': Int8ActPerTensorFloatMSE,
+                        'asym': ShiftedUint8ActPerTensorFloatMSE}}},
+            'po2_scale': {
+                'stats': {
+                    'per_tensor': {
+                        'sym': Int8ActPerTensorFixedPoint,
+                        'asym': ShiftedUint8ActPerTensorFixedPoint},},
+                'mse': {
+                    'per_tensor': {
+                        'sym': Int8ActPerTensorFixedPointMSE}}}},
+        'dynamic': {
+            'float_scale': {
+                'stats': {
+                    'per_tensor': {
+                        'sym': CNNInt8DynamicActPerTensorFloat,
+                        'asym': CNNShiftedUint8DynamicActPerTensorFloat}}}}},
     'float': {
         'float_scale': {
             'stats': {
@@ -146,6 +173,7 @@ def quantize_model(
         act_param_method='stats',
         weight_quant_type='sym',
         act_quant_granularity='per_tensor',
+        act_scale_computation_type='static',
         uint_sym_act_for_unsigned_values=True,
         dtype=torch.float32,
         device='cpu'):
@@ -165,8 +193,10 @@ def quantize_model(
         weight_mantissa_bit_width,
         weight_exponent_bit_width,
         act_mantissa_bit_width,
-        act_exponent_bit_width,
-    )
+        act_exponent_bit_width)
+
+    if act_scale_computation_type == 'dynamic':
+        assert bias_bit_width is None, "Bias quantization is not supported with dynamic activation quantization"
 
     weight_quant_format = quant_format
     act_quant_format = quant_format
@@ -253,6 +283,7 @@ def layerwise_bit_width_fn_weight(module):
         act_quant_type=act_quant_type,
         act_quant_granularity=act_quant_granularity,
         act_quant_percentile=act_quant_percentile,
+        act_scale_computation_type=act_scale_computation_type,
         **weight_bit_width_dict,
         **act_bit_width_dict)
 
@@ -288,6 +319,7 @@ def create_quant_maps(
         act_exponent_bit_width=None,
         act_bit_width=None,
         act_scale_type=None,
+        act_scale_computation_type=None,
         act_param_method=None,
         act_quant_type=None,
         act_quant_granularity=None,
@@ -317,14 +349,14 @@ def kwargs_prefix(prefix, weight_kwargs):
         weight_quant = weight_quant.let(**weight_bit_width_dict)
 
     if act_bit_width is not None:
-        act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_type][act_param_method][
-            act_quant_granularity][act_quant_type]
+        act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][act_scale_type][
+            act_param_method][act_quant_granularity][act_quant_type]
         # Some activations in MHA should always be symmetric
-        sym_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_type][act_param_method][
-            act_quant_granularity]['sym']
+        sym_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][
+            act_scale_type][act_param_method][act_quant_granularity]['sym']
         # Linear layers with 2d input should always be per tensor
-        per_tensor_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_type][act_param_method][
-            'per_tensor'][act_quant_type]
+        per_tensor_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][
+            act_scale_type][act_param_method]['per_tensor'][act_quant_type]
         act_quant = act_quant.let(**act_bit_width_dict)
         sym_act_quant = sym_act_quant.let(**act_bit_width_dict)
         per_tensor_act_quant = per_tensor_act_quant.let(**act_bit_width_dict)
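Together with the quant_blocks.py change above, the two CNN quantizer subclasses give every batch element its own scale factor (and, in the asymmetric case, zero point): OverBatchOverTensorView flattens each sample to a row, the statistic is reduced over dim 1 so one value per sample survives, and dynamic_scaling_broadcastable_fn views the result back to (N, 1, ..., 1) so it broadcasts against the original activation. A rough standalone equivalent follows, with an absmax statistic standing in for the injected stats_impl, which this patch does not show.

import torch


def per_sample_dynamic_scale(x: torch.Tensor, bit_width: int = 8) -> torch.Tensor:
    shape = x.shape
    stats_input = x.view(shape[0], -1)  # OverBatchOverTensorView: (N, rest)
    stats = stats_input.abs().amax(dim=1)  # stats_impl with stats_reduce_dim=1
    # dynamic_scaling_broadcastable_fn: back to (N, 1, ..., 1)
    stats = stats.view(shape[0], *[1 for _ in range(len(shape[1:]))])
    return stats / (2 ** (bit_width - 1) - 1)  # illustrative int8 scale mapping


x = torch.randn(4, 3, 32, 32)
scale = per_sample_dynamic_scale(x)  # shape: (4, 1, 1, 1)
x_fq = torch.clamp((x / scale).round(), -128, 127) * scale  # per-sample fake quant

Note that the dynamic branch of INPUT_QUANT_MAP only defines stats-based float scales: mse calibration has no offline statistics to optimize once scales are recomputed at runtime, and po2 scales remain static-only.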
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
index 1f7c06a2b..377e705ab 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
@@ -138,6 +138,11 @@ def parse_type(v, default_type):
     default='stats',
     choices=['stats', 'mse'],
     help='Activation quantization calibration type (default: stats)')
+parser.add_argument(
+    '--act-scale-computation-type',
+    default='static',
+    choices=['static', 'dynamic'],
+    help='Activation quantization scale computation type (default: static)')
 parser.add_argument(
     '--graph-eq-iterations',
     default=20,
@@ -411,11 +416,13 @@ def main():
         weight_exponent_bit_width=args.weight_exponent_bit_width,
         act_mantissa_bit_width=args.act_mantissa_bit_width,
         act_exponent_bit_width=args.act_exponent_bit_width,
+        act_scale_computation_type=args.act_scale_computation_type,
         uint_sym_act_for_unsigned_values=args.uint_sym_act_for_unsigned_values)
 
-    # Calibrate the quant_model on the calibration dataloader
-    print("Starting activation calibration:")
-    calibrate(calib_loader, quant_model)
+    if args.act_scale_computation_type == 'static':
+        # Calibrate the quant_model on the calibration dataloader
+        print("Starting activation calibration:")
+        calibrate(calib_loader, quant_model)
 
     if args.gpfq:
         print("Performing GPFQ:")
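Because act_scale_computation_type is now a key level inside each format entry of INPUT_QUANT_MAP, the three lookups in create_quant_maps thread it through between the quant format and the scale type. For example, the two symmetric per-tensor entries resolve as follows; the module paths are those used in this patch.

from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas_examples.imagenet_classification.ptq.ptq_common import (
    CNNInt8DynamicActPerTensorFloat, INPUT_QUANT_MAP)

assert INPUT_QUANT_MAP['int']['static']['float_scale']['stats']['per_tensor'][
    'sym'] is Int8ActPerTensorFloat
assert INPUT_QUANT_MAP['int']['dynamic']['float_scale']['stats']['per_tensor'][
    'sym'] is CNNInt8DynamicActPerTensorFloat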
print("Starting activation calibration:") + calibrate(calib_loader, quant_model) if args.gpfq: print("Performing GPFQ:")