diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 7e9b9c897..d1e8fbb0d 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -318,7 +318,7 @@ def generate_quantizers( if weight_group_dim is not None: weight_quant = weight_quant.let(**{'group_dim': weight_group_dim}) - if dtype != torch.float32: + if scaling_min_val is not None: weight_quant = weight_quant.let(**{'scaling_min_val': scaling_min_val}) input_quant = input_quant.let( **{'scaling_min_val': scaling_min_val}) if input_quant is not None else None