Skip to content

Commit

Permalink
New group_dim options for LLM
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 1, 2024
1 parent fd51bb8 commit 6297b17
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/brevitas_examples/common/generative/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def generate_quantizers(
weight_group_size,
quantize_weight_zero_point,
weight_quant_format='int',
weight_group_dim=None,
input_bit_width=None,
input_quant_format='',
input_scale_precision=None,
Expand Down Expand Up @@ -276,6 +277,10 @@ def generate_quantizers(
'narrow_range': False,
'quantize_zero_point': quantize_weight_zero_point},
**weight_float_format)

if weight_group_dim is not None:
weight_quant = weight_quant.let(**{'group_dim': weight_group_dim})

if dtype == torch.float16:
weight_quant = weight_quant.let(**{'scaling_min_val': 1e-4})
if weight_kwargs is not None:
Expand Down
7 changes: 7 additions & 0 deletions src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def main(args):
weight_quant_type=args.weight_quant_type,
weight_quant_granularity=args.weight_quant_granularity,
weight_group_size=args.weight_group_size,
weight_group_dim=args.weight_group_dim,
quantize_weight_zero_point=args.quantize_weight_zero_point,
weight_quant_format=args.weight_quant_format,
input_bit_width=args.input_bit_width,
Expand Down Expand Up @@ -358,6 +359,12 @@ def parse_args(args):
default='per_group',
choices=['per_channel', 'per_tensor', 'per_group'],
help='Granularity for scales/zero-point of weights. Default: per_group.')
parser.add_argument(
'--weight-group-dim',
type=int,
default=None,
choices=[1, 0],
help='Override default group_dim for groupsize quantization. Default: layer-dependent.')
parser.add_argument(
'--weight-group-size',
type=int,
Expand Down

0 comments on commit 6297b17

Please sign in to comment.