
Commit

Fix (examples/generative): set weight_bit_width in weight_quant (#783)
Giuseppe5 authored Dec 13, 2023
1 parent 84f4225 commit 52daf86
Showing 2 changed files with 3 additions and 10 deletions.
12 changes: 3 additions & 9 deletions src/brevitas_examples/common/generative/quantize.py
@@ -193,6 +193,7 @@ def quantize_model(
     # Modify the weight quantizer based on the arguments passed in
     weight_quant = weight_quant.let(
         **{
+            'bit_width': weight_bit_width,
             'narrow_range': False,
             'block_size': weight_group_size,
             'quantize_zero_point': quantize_weight_zero_point},
@@ -311,15 +312,8 @@ def quantize_model(
             'group_dim': 1, 'group_size': input_group_size})
 
     quant_linear_kwargs = {
-        'input_quant': linear_2d_input_quant,
-        'weight_quant': weight_quant,
-        'weight_bit_width': weight_bit_width,
-        'dtype': dtype}
-    quant_conv_kwargs = {
-        'input_quant': input_quant,
-        'weight_quant': weight_quant,
-        'weight_bit_width': weight_bit_width,
-        'dtype': dtype}
+        'input_quant': linear_2d_input_quant, 'weight_quant': weight_quant, 'dtype': dtype}
+    quant_conv_kwargs = {'input_quant': input_quant, 'weight_quant': weight_quant, 'dtype': dtype}
 
     quant_mha_kwargs = {
         'in_proj_input_quant': input_quant,
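For context, a minimal sketch of the pattern this commit relies on (assuming a recent Brevitas release; Int8WeightPerTensorFloat and the toy layer sizes below are illustrative, not taken from the diff): once 'bit_width' is set on the weight quantizer via .let(), a layer no longer needs a separate weight_bit_width keyword argument, which is why the second hunk can drop it from quant_linear_kwargs and quant_conv_kwargs.

    # Minimal sketch; assumes a recent Brevitas install, names are illustrative.
    import torch
    from brevitas.nn import QuantLinear
    from brevitas.quant import Int8WeightPerTensorFloat

    weight_bit_width = 4

    # Old style: bit width passed per layer, next to the quantizer.
    old_layer = QuantLinear(
        16, 16, bias=False,
        weight_quant=Int8WeightPerTensorFloat,
        weight_bit_width=weight_bit_width)

    # New style (what the commit does): bake the bit width into the
    # quantizer itself with .let(), then pass only the quantizer around.
    weight_quant = Int8WeightPerTensorFloat.let(bit_width=weight_bit_width)
    new_layer = QuantLinear(16, 16, bias=False, weight_quant=weight_quant)

    # Both layers quantize their weights to 4 bits.
    out = new_layer(torch.randn(2, 16))

The change in the first hunk does exactly this for the shared weight_quant used across the example, so every quantized layer built from those kwargs picks up the requested bit width without a per-layer override.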
1 change: 0 additions & 1 deletion src/brevitas_examples/llm/main.py
@@ -304,7 +304,6 @@ def main():
         seqlen=args.seqlen)
     # Tie back first/last layer weights in case they got untied
     model.tie_weights()
-    print(model)
     print("Model quantization applied.")
 
     if args.act_calibration:
