diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index c3d521964..71f518bb5 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -14,7 +14,7 @@ class FloatQuant(brevitas.jit.ScriptModule): - __constants__ = ['signed'] + __constants__ = ['signed', 'eps'] def __init__( self, diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py index f838bcf4f..96fafebee 100644 --- a/src/brevitas/utils/torch_utils.py +++ b/src/brevitas/utils/torch_utils.py @@ -96,7 +96,7 @@ def float_internal_scale( x: torch.Tensor, mantissa_bit_width: torch.Tensor, fp_internal_scale_min: torch.Tensor, - eps: torch.Tensor) -> torch.Tensor: + eps: float) -> torch.Tensor: internal_scale = floor_ste(torch.log2(torch.abs(x) + eps)) - mantissa_bit_width internal_scale = torch.clamp_min(internal_scale, fp_internal_scale_min)