diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index d6585a644..5198444b1 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -61,7 +61,7 @@ class ConstScaling(brevitas.jit.ScriptModule): def __init__( self, scaling_init: Union[float, Tensor], - restrict_scaling_impl: Optional[Module] = None, + restrict_scaling_impl: Module = FloatRestrictValue(), scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: @@ -69,18 +69,12 @@ def __init__( self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) if isinstance(scaling_init, Tensor): scaling_init = scaling_init.to(device=device, dtype=dtype) - if restrict_scaling_impl is not None: - scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init) - self.restrict_init_module = restrict_scaling_impl.restrict_init_module() - else: - self.restrict_init_module = Identity() + scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init) + self.restrict_init_module = restrict_scaling_impl.restrict_init_module() self.value = StatelessBuffer(scaling_init.detach()) else: - if restrict_scaling_impl is not None: - scaling_init = restrict_scaling_impl.restrict_init_float(scaling_init) - self.restrict_init_module = restrict_scaling_impl.restrict_init_module() - else: - self.restrict_init_module = Identity() + scaling_init = restrict_scaling_impl.restrict_init_float(scaling_init) + self.restrict_init_module = restrict_scaling_impl.restrict_init_module() self.value = StatelessBuffer(torch.tensor(scaling_init, dtype=dtype, device=device)) @brevitas.jit.script_method @@ -138,7 +132,7 @@ def __init__( self, scaling_init: Union[float, Tensor], scaling_shape: Optional[Tuple[int, ...]] = None, - restrict_scaling_impl: Optional[Module] = None, + restrict_scaling_impl: Module = FloatRestrictValue(), scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: @@ -153,11 +147,10 @@ def __init__( scaling_init = scaling_init.detach() else: scaling_init = torch.tensor(scaling_init, dtype=dtype, device=device) - if restrict_scaling_impl is not None: - scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init) - self.restrict_init_module = restrict_scaling_impl.restrict_init_module() - else: - self.restrict_init_module = Identity() + + scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init) + self.restrict_init_module = restrict_scaling_impl.restrict_init_module() + if scaling_init.shape == SCALAR_SHAPE and scaling_shape is not None: scaling_init = torch.full(scaling_shape, scaling_init, dtype=dtype, device=device) self.value = Parameter(scaling_init)