
Commit 52daf86

Fix (examples/generative): set weight_bit_width in weight_quant (#783)
1 parent 84f4225 commit 52daf86

File tree: 2 files changed, +3 -10 lines


src/brevitas_examples/common/generative/quantize.py

+3 -9
@@ -193,6 +193,7 @@ def quantize_model(
     # Modify the weight quantizer based on the arguments passed in
     weight_quant = weight_quant.let(
         **{
+            'bit_width': weight_bit_width,
             'narrow_range': False,
             'block_size': weight_group_size,
             'quantize_zero_point': quantize_weight_zero_point},
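
For context, this hunk relies on the Brevitas injector API, where .let() returns a copy of a quantizer with the given attributes rebound. A minimal sketch of that pattern, assuming the stock Int8WeightPerTensorFloat quantizer (the 4-bit value is illustrative, not from this commit):

# Binding a bit width directly into a Brevitas weight quantizer via .let(),
# mirroring what quantize_model now does with weight_bit_width.
from brevitas.quant import Int8WeightPerTensorFloat

weight_quant = Int8WeightPerTensorFloat.let(
    **{
        'bit_width': 4,  # illustrative override; previously passed per layer
        'narrow_range': False})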
@@ -311,15 +312,8 @@ def quantize_model(
             'group_dim': 1, 'group_size': input_group_size})

     quant_linear_kwargs = {
-        'input_quant': linear_2d_input_quant,
-        'weight_quant': weight_quant,
-        'weight_bit_width': weight_bit_width,
-        'dtype': dtype}
-    quant_conv_kwargs = {
-        'input_quant': input_quant,
-        'weight_quant': weight_quant,
-        'weight_bit_width': weight_bit_width,
-        'dtype': dtype}
+        'input_quant': linear_2d_input_quant, 'weight_quant': weight_quant, 'dtype': dtype}
+    quant_conv_kwargs = {'input_quant': input_quant, 'weight_quant': weight_quant, 'dtype': dtype}

     quant_mha_kwargs = {
         'in_proj_input_quant': input_quant,
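
Since the bit width is now bound into weight_quant itself, the per-layer 'weight_bit_width' entries above become redundant, which is why the kwargs dictionaries shrink. A minimal usage sketch under the same assumptions (the layer sizes are illustrative, not from this commit):

# A pre-configured quantizer is passed straight to a quant layer; no
# separate weight_bit_width keyword is needed.
import brevitas.nn as qnn
from brevitas.quant import Int8WeightPerTensorFloat

weight_quant = Int8WeightPerTensorFloat.let(bit_width=4)  # illustrative bit width
layer = qnn.QuantLinear(64, 64, bias=False, weight_quant=weight_quant)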

src/brevitas_examples/llm/main.py

-1
@@ -304,7 +304,6 @@ def main():
         seqlen=args.seqlen)
     # Tie back first/last layer weights in case they got untied
     model.tie_weights()
-    print(model)
     print("Model quantization applied.")

     if args.act_calibration:
