From 8c2454df78b7a7485c1394188f15437b4d54444f Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 29 Oct 2024 09:10:49 +0000 Subject: [PATCH] Add default --- src/brevitas/core/scaling/runtime.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py index 9792ebdae..09f891ed7 100644 --- a/src/brevitas/core/scaling/runtime.py +++ b/src/brevitas/core/scaling/runtime.py @@ -199,6 +199,11 @@ def __init__( restrict_scaling_impl: Module = FloatRestrictValue(), restrict_threshold_impl: Optional[Module] = None) -> None: super(RuntimeDynamicGroupStatsScaling, self).__init__() + + # Ensure retro-compatibility with shared threshold/scaling restrict + if restrict_threshold_impl is None: + restrict_threshold_impl = restrict_scaling_impl + self.group_size = group_size self.group_dim = group_dim self.scaling_stats_impl = scaling_stats_impl