diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index f19431605..5098ab3f0 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -187,9 +187,13 @@ def quantize_model( **{ 'bit_width': weight_bit_width, 'narrow_range': False, - 'group_size': weight_group_size, 'quantize_zero_point': quantize_weight_zero_point}, **weight_float_format) + + # Set the group_size is we're doing groupwise quantization + if weight_quant_granularity == 'per_group': + weight_quant = weight_quant.let( + **{'group_size': weight_group_size}) # weight scale is converted to a standalone parameter # This is done already by default in the per_group quantizer if weight_quant_granularity != 'per_group':