Fix (groupwise): correct log, groupdim, and scale computation #1071

Merged Oct 31, 2024 · 10 commits · Changes from all commits
5 changes: 3 additions & 2 deletions src/brevitas/core/function_wrapper/shape.py

@@ -195,8 +195,9 @@ def forward(self, x):

         tensor_shape = x.shape
         tensor_shape_list = list(tensor_shape)
-        tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size)
-        block_dim = self.group_dim + 1 if self.group_dim != -1 else -1
+        tensor_shape_list[self.group_dim] = (
+            tensor_shape_list[self.group_dim] + self.group_size - 1) // self.group_size
+        block_dim = self.group_dim + 1 if self.group_dim != -1 else len(tensor_shape_list)
         tensor_shape_list.insert(block_dim, self.group_size)
         x = x.view(tensor_shape_list)
         return x
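A note on the two fixes above: the old floor division silently dropped elements whenever the grouped dimension was not an exact multiple of group_size, and list.insert(-1, ...) places the new block axis before the last position rather than after it. A minimal standalone sketch of the corrected reshape (hypothetical group_view helper, not the Brevitas API):

import torch

def group_view(x: torch.Tensor, group_dim: int, group_size: int) -> torch.Tensor:
    shape = list(x.shape)
    # Ceil division: tolerates a grouped dimension padded up to a multiple of group_size.
    shape[group_dim] = (shape[group_dim] + group_size - 1) // group_size
    # For group_dim == -1 the block axis must land at the end (len(shape));
    # insert(-1, ...) would put it before the last axis and group the wrong dim.
    block_dim = group_dim + 1 if group_dim != -1 else len(shape)
    shape.insert(block_dim, group_size)
    return x.view(shape)

x = torch.arange(24.).reshape(2, 12)
print(group_view(x, group_dim=-1, group_size=4).shape)  # torch.Size([2, 3, 4])
# The old `else -1` branch would have produced shape [2, 4, 3] instead,
# grouping along the wrong axis.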
17 changes: 4 additions & 13 deletions src/brevitas/core/restrict_val.py

@@ -24,7 +24,10 @@

 class _RestrictClampValue(brevitas.jit.ScriptModule):

-    def __init__(self, scaling_min_val: Optional[float], restrict_value_impl: Optional[Module]):
+    def __init__(
+            self,
+            scaling_min_val: Optional[float] = None,
+            restrict_value_impl: Optional[Module] = None):
         super(_RestrictClampValue, self).__init__()
         if scaling_min_val is not None and scaling_min_val != 0:
             self.clamp_min_ste = ScalarClampMinSte(scaling_min_val)

@@ -90,9 +93,6 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return Identity()

-    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
-        return x / threshold
-
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tensor:
         return x

@@ -116,9 +116,6 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return InplaceLogTwo()

-    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
-        return x - threshold
-
     @brevitas.jit.script_method
     def forward(self, x: Tensor):
         x = self.power_of_two(x)

@@ -143,9 +140,6 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return Identity()

-    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
-        return x / threshold
-
     @brevitas.jit.script_method
     def forward(self, x: Tensor):
         x = self.float_to_int_impl(x)

@@ -171,9 +165,6 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return InplaceLogTwo()

-    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
-        return x - threshold
-
     @brevitas.jit.script_method
     def forward(self, x: Tensor):
         x = self.float_to_int_impl(x)
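The four combine_scale_threshold removals are safe because combining in the restricted domain is equivalent to dividing in the linear domain once both operands have been mapped back: for a log2 restriction, x - threshold in log space exponentiates to a plain division. A quick numerical check of that identity (plain PyTorch, not Brevitas code):

import torch

# For a power-of-two restriction: subtracting in log2 space equals dividing
# in linear space after exponentiation. This is why the per-restrictor
# combine_scale_threshold hooks can be replaced by one linear-domain division.
log2_scale = torch.log2(torch.tensor([0.5, 2.0, 8.0]))
log2_threshold = torch.log2(torch.tensor([4.0]))

combined_in_log = 2.0 ** (log2_scale - log2_threshold)
combined_in_linear = (2.0 ** log2_scale) / (2.0 ** log2_threshold)
assert torch.allclose(combined_in_log, combined_in_linear)

The new _RestrictClampValue defaults also allow constructing a threshold-only restrictor as _RestrictClampValue(restrict_value_impl=...), which the scaling modules below rely on.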
56 changes: 44 additions & 12 deletions src/brevitas/core/scaling/runtime.py

@@ -30,12 +30,18 @@ def __init__(
             tracked_parameter_list: List[torch.nn.Parameter],
             scaling_shape: Tuple[int, ...],
             restrict_scaling_impl: Module = FloatRestrictValue(),
+            restrict_threshold_impl: Optional[Module] = None,
             affine_rescaling: bool = False,
             affine_shift_scale: bool = False,
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
             device: Optional[torch.device] = None) -> None:
         super(StatsFromParameterScaling, self).__init__()
+
+        # Ensure retro-compatibility with shared threshold/scaling restrict
+        if restrict_threshold_impl is None:
+            restrict_threshold_impl = restrict_scaling_impl
+
         self.parameter_list_stats = _ParameterListStats(
             scaling_stats_impl,
             scaling_shape,

@@ -44,6 +50,7 @@ def __init__(
             tracked_parameter_list)
         self.stats_scaling_impl = _StatsScaling(
             restrict_scaling_impl,
+            restrict_threshold_impl,
             scaling_shape,
             scaling_min_val,
             affine_rescaling,

@@ -65,6 +72,7 @@ class _StatsScaling(brevitas.jit.ScriptModule):
     def __init__(
             self,
             restrict_scaling_impl: Module,
+            restrict_threshold_impl: Module,
             scaling_shape: Tuple[int, ...],
             scaling_min_val: Optional[float],
             affine_rescaling: bool,

@@ -81,19 +89,22 @@ def __init__(
         else:
             self.affine_rescaling = Identity()
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
+        self.restrict_clamp_threshold = _RestrictClampValue(
+            restrict_value_impl=restrict_threshold_impl)
         self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
-        self.restrict_scaling_impl = restrict_scaling_impl
+        self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module()

     @brevitas.jit.script_method
     def forward(
             self, stats: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
         if threshold is None:
             threshold = torch.ones(1).type_as(stats)
-        threshold = self.restrict_scaling_pre(threshold)
+        threshold = self.restrict_threshold_pre(threshold)
+        threshold = self.restrict_clamp_threshold(threshold)
         stats = self.restrict_scaling_pre(stats)
-        stats = self.restrict_scaling_impl.combine_scale_threshold(stats, threshold)
         stats = self.affine_rescaling(stats)
         stats = self.restrict_clamp_scaling(stats)
+        stats = stats / threshold
         return stats
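The practical gain from a separate restrict_threshold_impl is that scale and threshold no longer have to share one restriction: each gets its own pre-transform and clamp, and the division happens once in the linear domain. A standalone sketch of the new forward ordering with stand-in restrictors (hypothetical lambdas, not the Brevitas modules), using a power-of-two scale and an unrestricted float threshold:

import torch

# Stand-ins for the restrictor modules (assumed behavior, illustration only).
restrict_scaling_pre = torch.log2                          # scale pre-transform (log domain)
restrict_clamp_scaling = lambda x: 2.0 ** torch.round(x)   # back to linear, po2-rounded
restrict_threshold_pre = lambda x: x                       # float threshold: identity
restrict_clamp_threshold = lambda x: x

stats = torch.tensor([0.3, 5.0])
threshold = torch.tensor([2.0])

# Mirrors the reworked _StatsScaling.forward: restrict each operand
# independently, then divide once in the linear domain.
threshold = restrict_clamp_threshold(restrict_threshold_pre(threshold))
scale = restrict_clamp_scaling(restrict_scaling_pre(stats))
scale = scale / threshold
print(scale)  # tensor([0.1250, 2.0000])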


@@ -107,12 +118,17 @@ def __init__(
             affine_rescaling: bool = False,
             affine_shift_scale: bool = False,
             restrict_scaling_impl: Module = FloatRestrictValue(),
+            restrict_threshold_impl: Optional[Module] = None,
             scaling_stats_momentum: float = DEFAULT_MOMENTUM,
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
             device: Optional[torch.device] = None) -> None:
         super(RuntimeStatsScaling, self).__init__()
+
+        # Ensure retro-compatibility with shared threshold/scaling restrict
+        if restrict_threshold_impl is None:
+            restrict_threshold_impl = restrict_scaling_impl

         self.runtime_stats = _RuntimeStats(
             scaling_stats_impl,
             scaling_shape,

@@ -122,6 +138,7 @@ def __init__(
             device)
         self.stats_scaling_impl = _StatsScaling(
             restrict_scaling_impl,
+            restrict_threshold_impl,
             scaling_shape,
             scaling_min_val,
             affine_rescaling,

@@ -173,20 +190,32 @@ def _load_from_state_dict(
 class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule):

     def __init__(
-            self,
-            group_size: int,
-            group_dim: int,
-            input_view_impl: Module,
-            scaling_stats_impl: Module,
-            scaling_min_val: Optional[float],
-            restrict_scaling_impl: Module = FloatRestrictValue()) -> None:
+            self,
+            group_size: int,
+            group_dim: int,
+            input_view_impl: Module,
+            scaling_stats_impl: Module,
+            scaling_min_val: Optional[float],
+            restrict_scaling_impl: Module = FloatRestrictValue(),
+            restrict_threshold_impl: Optional[Module] = None) -> None:
         super(RuntimeDynamicGroupStatsScaling, self).__init__()
+
+        # Ensure retro-compatibility with shared threshold/scaling restrict
+        if restrict_threshold_impl is None:
+            restrict_threshold_impl = restrict_scaling_impl
+
         self.group_size = group_size
         self.group_dim = group_dim
         self.scaling_stats_impl = scaling_stats_impl
         self.scaling_min_val = scaling_min_val
         self.input_view_impl = input_view_impl
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
+        self.restrict_clamp_threshold = _RestrictClampValue(
+            restrict_value_impl=restrict_threshold_impl)
+        self.restrict_scaling_pre = self.restrict_clamp_scaling.restrict_value_impl.restrict_init_module(
+        )
+        self.restrict_threshold_pre = self.restrict_clamp_threshold.restrict_value_impl.restrict_init_module(
+        )

@@ -196,7 +225,10 @@ def forward(
         if threshold is None:
             threshold = torch.ones(1).type_as(stats_input)
         stats_input_reshaped = self.input_view_impl(stats_input)
-        out = self.scaling_stats_impl(stats_input_reshaped) / threshold
+        threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold))
+        out = self.scaling_stats_impl(stats_input_reshaped)
+        # Apply log scaling
+        out = self.restrict_scaling_pre(out)
+        # Scaling min val
+        out = self.restrict_clamp_scaling(out) / threshold
         return out
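Taken together with the shape.py fix, the dynamic group path now computes one statistic per group, applies the log restriction and min-val clamp to that statistic, and divides by the separately restricted threshold. A rough end-to-end sketch under assumed shapes (absmax statistic over the last dim; not the Brevitas modules themselves):

import torch

def dynamic_group_scale(x: torch.Tensor, group_size: int,
                        threshold: torch.Tensor) -> torch.Tensor:
    shape = list(x.shape)
    shape[-1] = (shape[-1] + group_size - 1) // group_size  # ceil division, as fixed above
    shape.append(group_size)                                # block axis at the end
    grouped = x.view(shape)
    stats = grouped.abs().amax(dim=-1, keepdim=True)        # one absmax per group
    return stats / threshold                                # linear-domain division

x = torch.randn(2, 8)
print(dynamic_group_scale(x, group_size=4, threshold=torch.ones(1)).shape)
# torch.Size([2, 2, 1])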