From 6297b173b73ae98687a42d1a09dfde342880f6eb Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 1 Oct 2024 12:25:41 +0100
Subject: [PATCH] New groupdim options for LLM

---
 src/brevitas_examples/common/generative/quantize.py | 5 +++++
 src/brevitas_examples/llm/main.py                    | 7 +++++++
 2 files changed, 12 insertions(+)

diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py
index 5259c6776..15b47884f 100644
--- a/src/brevitas_examples/common/generative/quantize.py
+++ b/src/brevitas_examples/common/generative/quantize.py
@@ -210,6 +210,7 @@ def generate_quantizers(
         weight_group_size,
         quantize_weight_zero_point,
         weight_quant_format='int',
+        weight_group_dim=None,
         input_bit_width=None,
         input_quant_format='',
         input_scale_precision=None,
@@ -276,6 +277,10 @@ def generate_quantizers(
                 'narrow_range': False,
                 'quantize_zero_point': quantize_weight_zero_point},
             **weight_float_format)
+
+    if weight_group_dim is not None:
+        weight_quant = weight_quant.let(**{'group_dim': weight_group_dim})
+
     if dtype == torch.float16:
         weight_quant = weight_quant.let(**{'scaling_min_val': 1e-4})
     if weight_kwargs is not None:
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index e19390774..67c8144c7 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -226,6 +226,7 @@ def main(args):
         weight_quant_type=args.weight_quant_type,
         weight_quant_granularity=args.weight_quant_granularity,
         weight_group_size=args.weight_group_size,
+        weight_group_dim=args.weight_group_dim,
         quantize_weight_zero_point=args.quantize_weight_zero_point,
         weight_quant_format=args.weight_quant_format,
         input_bit_width=args.input_bit_width,
@@ -358,6 +359,12 @@ def parse_args(args):
         default='per_group',
         choices=['per_channel', 'per_tensor', 'per_group'],
         help='Granularity for scales/zero-point of weights. Default: per_group.')
+    parser.add_argument(
+        '--weight-group-dim',
+        type=int,
+        default=None,
+        choices=[1, 0],
+        help='Override default group_dim for groupsize quantization. Default: layer-dependent.')
     parser.add_argument(
         '--weight-group-size',
         type=int,
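
Usage note (not part of the patch): the sketch below is a minimal, self-contained illustration of what choosing group_dim 0 versus 1 means for a 2D Linear weight under group-wise quantization. It does not use Brevitas internals; the helper name and the symmetric absmax/int8 scale rule are assumptions for illustration only.

import torch

def group_scales(weight: torch.Tensor, group_size: int, group_dim: int) -> torch.Tensor:
    """Compute one symmetric absmax scale per group of `group_size`
    consecutive elements along `group_dim` (illustrative int8 mapping)."""
    shape = list(weight.shape)
    assert shape[group_dim] % group_size == 0
    n_groups = shape[group_dim] // group_size
    # Split `group_dim` into (n_groups, group_size) so each group can be
    # reduced independently.
    new_shape = shape[:group_dim] + [n_groups, group_size] + shape[group_dim + 1:]
    grouped = weight.reshape(new_shape)
    return grouped.abs().amax(dim=group_dim + 1) / 127.0

w = torch.randn(64, 128)                       # e.g. a Linear weight [out_features, in_features]
print(group_scales(w, 32, group_dim=1).shape)  # torch.Size([64, 4]): groups run along input channels
print(group_scales(w, 32, group_dim=0).shape)  # torch.Size([2, 128]): groups run along output channels

With the patch applied, passing --weight-group-dim 0 or 1 to the LLM entry point forces the corresponding axis instead of the layer-dependent default.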