Skip to content

Commit

Permalink
fix for input_quant
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 29, 2024
1 parent 79c5811 commit 726356a
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/brevitas_examples/common/generative/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,14 @@ def generate_quantizers(
scale_rounding_func_dict = {'ceil': CeilSte, 'floor': FloorSte, 'round': RoundSte}
scale_type = scale_rounding_func_dict[scale_rounding_func_type]
weight_quant = weight_quant.let(**{'restrict_value_float_to_int_impl': scale_type})
input_quant = input_quant.let(**{'restrict_value_float_to_int_impl': scale_type})
sym_input_quant = sym_input_quant.let(**{'restrict_value_float_to_int_impl': scale_type})
linear_input_quant = linear_input_quant.let(
**{'restrict_value_float_to_int_impl': scale_type})
if input_quant is not None:
input_quant = input_quant.let(**{'restrict_value_float_to_int_impl': scale_type})
if sym_input_quant is not None:
sym_input_quant = sym_input_quant.let(
**{'restrict_value_float_to_int_impl': scale_type})
if linear_input_quant is not None:
linear_input_quant = linear_input_quant.let(
**{'restrict_value_float_to_int_impl': scale_type})

if weight_group_dim is not None:
weight_quant = weight_quant.let(**{'group_dim': weight_group_dim})
Expand Down

0 comments on commit 726356a

Please sign in to comment.