From 52daf8668e4842e1f645640f85b14fc7c220276e Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Wed, 13 Dec 2023 14:39:46 +0100
Subject: [PATCH] Fix (examples/generative): set weight_bit_width in weight_quant (#783)

---
 src/brevitas_examples/common/generative/quantize.py | 12 +++---------
 src/brevitas_examples/llm/main.py                    |  1 -
 2 files changed, 3 insertions(+), 10 deletions(-)

diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py
index c0d76559d..1d316d9a9 100644
--- a/src/brevitas_examples/common/generative/quantize.py
+++ b/src/brevitas_examples/common/generative/quantize.py
@@ -193,6 +193,7 @@ def quantize_model(
     # Modify the weight quantizer based on the arguments passed in
     weight_quant = weight_quant.let(
         **{
+            'bit_width': weight_bit_width,
             'narrow_range': False,
             'block_size': weight_group_size,
             'quantize_zero_point': quantize_weight_zero_point},
@@ -311,15 +312,8 @@ def quantize_model(
                 'group_dim': 1,
                 'group_size': input_group_size})
     quant_linear_kwargs = {
-        'input_quant': linear_2d_input_quant,
-        'weight_quant': weight_quant,
-        'weight_bit_width': weight_bit_width,
-        'dtype': dtype}
-    quant_conv_kwargs = {
-        'input_quant': input_quant,
-        'weight_quant': weight_quant,
-        'weight_bit_width': weight_bit_width,
-        'dtype': dtype}
+        'input_quant': linear_2d_input_quant, 'weight_quant': weight_quant, 'dtype': dtype}
+    quant_conv_kwargs = {'input_quant': input_quant, 'weight_quant': weight_quant, 'dtype': dtype}
 
     quant_mha_kwargs = {
         'in_proj_input_quant': input_quant,
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 17c0d5fe7..4d8b2c3ef 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -304,7 +304,6 @@ def main():
             seqlen=args.seqlen)
         # Tie back first/last layer weights in case they got untied
         model.tie_weights()
-        print(model)
         print("Model quantization applied.")
 
         if args.act_calibration:
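
For context, the patch folds the bit width into the weight quantizer itself via .let() instead of passing a separate weight_bit_width kwarg to each quantized layer. A minimal sketch of that pattern follows; it is not the repository's exact code, and the quantizer and layer classes (Int8WeightPerTensorFloat, QuantLinear) are illustrative stock Brevitas choices, not necessarily the ones the example scripts use.

    # Minimal sketch, assuming stock Brevitas classes; bit_width is set on the
    # quantizer through .let(), mirroring the patch, so the layer no longer
    # needs a separate weight_bit_width argument.
    from brevitas.nn import QuantLinear
    from brevitas.quant import Int8WeightPerTensorFloat  # illustrative base quantizer

    weight_bit_width = 4
    weight_quant = Int8WeightPerTensorFloat.let(
        bit_width=weight_bit_width,  # carried by the quantizer, as in the patch
        narrow_range=False)

    # The layer only receives the quantizer; the bit width travels with it.
    layer = QuantLinear(64, 64, bias=False, weight_quant=weight_quant)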