From 0ea7bac8f7d7b687c1ac0c8cb4712ad9885645c5 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 6 Dec 2024 09:52:45 +0000 Subject: [PATCH] Fix (brevitas_examples/generative): scaling_min_val for any dtype (#1117) --- src/brevitas_examples/common/generative/quantize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 6c156ec1a..778955285 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -318,7 +318,7 @@ def generate_quantizers( if weight_group_dim is not None: weight_quant = weight_quant.let(**{'group_dim': weight_group_dim}) - if dtype != torch.float32: + if scaling_min_val is not None: weight_quant = weight_quant.let(**{'scaling_min_val': scaling_min_val}) input_quant = input_quant.let( **{'scaling_min_val': scaling_min_val}) if input_quant is not None else None