From 8e4d42ed72324e2e6651faa49c19e53b524c381c Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 24 Oct 2024 10:21:15 +0100
Subject: [PATCH 01/10] Fix (groupwise): correct log and groupdim

---
 src/brevitas/core/scaling/runtime.py | 4 ++++
 src/brevitas/quant/solver/common.py  | 3 ++-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py
index f11eb1f2a..0e6037903 100644
--- a/src/brevitas/core/scaling/runtime.py
+++ b/src/brevitas/core/scaling/runtime.py
@@ -187,6 +187,8 @@ def __init__(
         self.scaling_min_val = scaling_min_val
         self.input_view_impl = input_view_impl
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
+        self.restrict_module = self.restrict_clamp_scaling.restrict_value_impl.restrict_init_module(
+        )
 
     @brevitas.jit.script_method
     def forward(
@@ -197,6 +199,8 @@ def forward(
             threshold = torch.ones(1).type_as(stats_input)
         stats_input_reshaped = self.input_view_impl(stats_input)
         out = self.scaling_stats_impl(stats_input_reshaped) / threshold
+        # Apply log scaling
+        out = self.restrict_module(out)
         # Scaling min val
         out = self.restrict_clamp_scaling(out)
         return out
diff --git a/src/brevitas/quant/solver/common.py b/src/brevitas/quant/solver/common.py
index 4d46cc704..a4930e43d 100644
--- a/src/brevitas/quant/solver/common.py
+++ b/src/brevitas/quant/solver/common.py
@@ -178,7 +178,8 @@ def stats_reduce_dim(scaling_stats_op, scaling_per_output, group_dim=None):
         elif scaling_per_output == ScalingPerOutputType.TENSOR:
             return None
         elif scaling_per_output == ScalingPerOutputType.GROUP:
-            return group_dim + 1
+            reduce_dim = group_dim + 1 if group_dim != -1 else -1
+            return reduce_dim
 
     @value
     def keepdim(scaling_per_output):

From cac72892f5b25f8eb8c622b9388370bd147abe2b Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 24 Oct 2024 12:56:14 +0100
Subject: [PATCH 02/10] More fix

---
 src/brevitas/core/function_wrapper/shape.py     | 5 +++--
 src/brevitas/core/scaling/runtime.py            | 5 +++--
 src/brevitas/quant/experimental/mx_quant_ocp.py | 5 +++--
 3 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/brevitas/core/function_wrapper/shape.py b/src/brevitas/core/function_wrapper/shape.py
index e175e4445..e8b42312a 100644
--- a/src/brevitas/core/function_wrapper/shape.py
+++ b/src/brevitas/core/function_wrapper/shape.py
@@ -195,8 +195,9 @@ def forward(self, x):
         tensor_shape = x.shape
         tensor_shape_list = list(tensor_shape)
-        tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size)
-        block_dim = self.group_dim + 1 if self.group_dim != -1 else -1
+        tensor_shape_list[self.group_dim] = (
+            tensor_shape_list[self.group_dim] + self.group_size - 1) // self.group_size
+        block_dim = self.group_dim + 1 if self.group_dim != -1 else len(tensor_shape_list)
         tensor_shape_list.insert(block_dim, self.group_size)
         x = x.view(tensor_shape_list)
         return x
diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py
index 0e6037903..2dc4cea1c 100644
--- a/src/brevitas/core/scaling/runtime.py
+++ b/src/brevitas/core/scaling/runtime.py
@@ -198,9 +198,10 @@ def forward(
         if threshold is None:
             threshold = torch.ones(1).type_as(stats_input)
         stats_input_reshaped = self.input_view_impl(stats_input)
-        out = self.scaling_stats_impl(stats_input_reshaped) / threshold
+        threshold = self.restrict_clamp_scaling(self.restrict_module(threshold))
+        out = self.scaling_stats_impl(stats_input_reshaped)
         # Apply log scaling
         out = self.restrict_module(out)
         # Scaling min val
-        out = self.restrict_clamp_scaling(out)
+        out = self.restrict_clamp_scaling(out) / threshold
         return out
diff --git a/src/brevitas/quant/experimental/mx_quant_ocp.py b/src/brevitas/quant/experimental/mx_quant_ocp.py
index 2299c1783..551f4f3d7 100644
--- a/src/brevitas/quant/experimental/mx_quant_ocp.py
+++ b/src/brevitas/quant/experimental/mx_quant_ocp.py
@@ -4,6 +4,7 @@
 from dependencies import value
 
 from brevitas.core.function_wrapper.ops_ste import CeilSte
+from brevitas.core.function_wrapper.ops_ste import FloorSte
 from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling
 from brevitas.inject import ExtendedInjector
 from brevitas.inject.enum import RestrictValueType
@@ -46,14 +47,14 @@ class GroupwiseActProxyMixin(ExtendedInjector):
 class MXWeightMixin(ExtendedInjector):
     group_size = 32
     restrict_scaling_type = RestrictValueType.POWER_OF_TWO
-    restrict_value_float_to_int_impl = CeilSte
+    restrict_value_float_to_int_impl = FloorSte
     scaling_per_output_type = ScalingPerOutputType.GROUP
 
 
 class MXActMixin(ExtendedInjector):
     group_size = 32
     restrict_scaling_type = RestrictValueType.POWER_OF_TWO
-    restrict_value_float_to_int_impl = CeilSte
+    restrict_value_float_to_int_impl = FloorSte
     scaling_impl = RuntimeDynamicGroupStatsScaling
     scaling_per_output_type = ScalingPerOutputType.GROUP
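
[Review note, not part of the patch series] PATCH 02 replaces the truncating division in the groupwise view with a ceiling division and fixes the block-dim insertion for group_dim == -1. A self-contained sketch of the same reshape logic (tensor sizes here are invented for illustration):

import torch

def group_view(x: torch.Tensor, group_dim: int, group_size: int) -> torch.Tensor:
    shape = list(x.shape)
    # Ceiling division: the old int(C / group_size) truncated, so a channel
    # count that is not a multiple of group_size lost its last partial group.
    shape[group_dim] = (shape[group_dim] + group_size - 1) // group_size
    # With group_dim == -1, inserting at index -1 would put the block of
    # group_size elements *before* the last dim; len(shape) appends it after.
    block_dim = group_dim + 1 if group_dim != -1 else len(shape)
    shape.insert(block_dim, group_size)
    return x.view(shape)

x = torch.randn(4, 64)
print(group_view(x, group_dim=-1, group_size=32).shape)  # torch.Size([4, 2, 32])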
From 090e0339162681ec4db17ad1c29f4aa29056bd5c Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 24 Oct 2024 21:57:23 +0100
Subject: [PATCH 03/10] More fixes

---
 src/brevitas/core/restrict_val.py        | 12 --------
 src/brevitas/core/scaling/runtime.py     |  3 +-
 src/brevitas/core/scaling/standalone.py  | 38 ++++++++++++++----------
 tests/brevitas/graph/test_calibration.py |  3 +-
 4 files changed, 25 insertions(+), 31 deletions(-)

diff --git a/src/brevitas/core/restrict_val.py b/src/brevitas/core/restrict_val.py
index 59b3fe8ec..7eb9845f9 100644
--- a/src/brevitas/core/restrict_val.py
+++ b/src/brevitas/core/restrict_val.py
@@ -90,9 +90,6 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return Identity()
 
-    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
-        return x / threshold
-
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tensor:
         return x
@@ -116,9 +113,6 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return InplaceLogTwo()
 
-    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
-        return x - threshold
-
     @brevitas.jit.script_method
     def forward(self, x: Tensor):
         x = self.power_of_two(x)
@@ -143,9 +137,6 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return Identity()
 
-    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
-        return x / threshold
-
     @brevitas.jit.script_method
     def forward(self, x: Tensor):
         x = self.float_to_int_impl(x)
@@ -171,9 +162,6 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return InplaceLogTwo()
 
-    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
-        return x - threshold
-
     @brevitas.jit.script_method
     def forward(self, x: Tensor):
         x = self.float_to_int_impl(x)
diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py
index 2dc4cea1c..fee4175bc 100644
--- a/src/brevitas/core/scaling/runtime.py
+++ b/src/brevitas/core/scaling/runtime.py
@@ -90,10 +90,11 @@ def forward(
         if threshold is None:
             threshold = torch.ones(1).type_as(stats)
         threshold = self.restrict_scaling_pre(threshold)
+        threshold = self.restrict_clamp_scaling(threshold)
         stats = self.restrict_scaling_pre(stats)
-        stats = self.restrict_scaling_impl.combine_scale_threshold(stats, threshold)
         stats = self.affine_rescaling(stats)
         stats = self.restrict_clamp_scaling(stats)
+        stats = stats / threshold
         return stats
diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py
index 4917b859a..e43fd577a 100644
--- a/src/brevitas/core/scaling/standalone.py
+++ b/src/brevitas/core/scaling/standalone.py
@@ -220,9 +220,10 @@ def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
         # This is because we don't want to store a parameter dependant on a runtime value (threshold)
         # And because restrict needs to happen after we divide by threshold
         if self.init_done:
-            threshold = self.restrict_inplace_preprocess(threshold)
-            value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
-            value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value))
+            threshold = self.stats_scaling_impl.restrict_clamp_scaling(
+                self.restrict_preprocess(threshold))
+            value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value))
+            value = value / threshold
             return value
         else:
             stats = self.parameter_list_stats()
@@ -231,10 +232,11 @@ def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
             stats = stats + 0. * self.value
             if self.local_loss_mode:
                 return self.stats_scaling_impl(stats, threshold)
             stats = self.restrict_inplace_preprocess(stats)
-            threshold = self.restrict_inplace_preprocess(threshold)
+            threshold = self.stats_scaling_impl.restrict_clamp_scaling(
+                self.restrict_preprocess(threshold))
             inplace_tensor_mul(self.value.detach(), stats)
-            value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
-            value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value))
+            value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value))
+            value = value / threshold
             self.init_done = True
             return value
@@ -360,14 +362,16 @@ def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor:
         elif self.counter == self.collect_stats_steps:
             self.restrict_inplace_preprocess(self.buffer)
             inplace_tensor_mul(self.value.detach(), self.buffer)
-            threshold = self.restrict_preprocess(threshold)
-            value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
+            threshold = self.restrict_scaling(self.restrict_preprocess(threshold))
+            value = self.clamp_scaling(self.restrict_scaling(self.value))
+            value = value / threshold
             self.counter = self.counter + 1
-            return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
+            return abs_binary_sign_grad(value)
         else:
-            threshold = self.restrict_preprocess(threshold)
-            value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
-            return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
+            threshold = self.restrict_scaling(self.restrict_preprocess(threshold))
+            value = self.clamp_scaling(self.restrict_scaling(self.value))
+            value = value / threshold
+            return abs_binary_sign_grad(value)
 
     @brevitas.jit.script_method
     def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
@@ -378,12 +382,14 @@ def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
             return self.training_forward(stats_input, threshold)
         else:
             if self.counter <= self.collect_stats_steps:
-                out = self.buffer / threshold
+                out = self.buffer
                 out = self.restrict_preprocess(out)
             else:
-                threshold = self.restrict_preprocess(threshold)
-                out = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
-                out = abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(out)))
+                out = self.value
+            threshold = self.restrict_scaling(self.restrict_preprocess(threshold))
+            out = self.clamp_scaling(self.restrict_scaling(out))
+            out = out / threshold
+            out = abs_binary_sign_grad(self.clamp_scaling(out))
             return out
diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py
index fbfc76842..16f944e97 100644
--- a/tests/brevitas/graph/test_calibration.py
+++ b/tests/brevitas/graph/test_calibration.py
@@ -60,7 +60,7 @@ def reference_implementation_scale_factors_po2(
     return scale
 
 
-@given(inp=float_tensor_random_size_st())
+@given(inp=float_tensor_random_size_st(max_val=1e10, min_val=-1e10))
 def test_scale_factors_ptq_calibration_po2(inp):
 
     class TestModel(nn.Module):
@@ -80,7 +80,6 @@ def forward(self, x):
 
     expected_scale = reference_implementation_scale_factors_po2(inp)
     scale = model.act.act_quant.scale()
-
     assert torch.allclose(expected_scale, scale)
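
[Review note, not part of the patch series] PATCH 03 makes both the scale statistics and the threshold pass through the restriction before the division, instead of dividing raw statistics by the threshold first. For a power-of-two restriction the two orderings genuinely differ; a toy comparison (values invented for illustration):

import torch

def po2_floor(x: torch.Tensor) -> torch.Tensor:
    # Power-of-two restriction with floor rounding, as the MX quantizers use.
    return torch.exp2(torch.floor(torch.log2(x)))

stats, threshold = torch.tensor(10.0), torch.tensor(3.0)
old_order = po2_floor(stats / threshold)             # tensor(2.)
new_order = po2_floor(stats) / po2_floor(threshold)  # tensor(4.)
print(old_order, new_order)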
From ff72df76576c02c07c40ead0064620c7b40c10b5 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Mon, 28 Oct 2024 12:24:56 +0000
Subject: [PATCH 04/10] Decouple threshold restrict impl from scaling

---
 src/brevitas/core/scaling/runtime.py    | 47 +++++++++++-----
 src/brevitas/core/scaling/standalone.py | 75 ++++++++++++++++++-------
 2 files changed, 89 insertions(+), 33 deletions(-)

diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py
index fee4175bc..9792ebdae 100644
--- a/src/brevitas/core/scaling/runtime.py
+++ b/src/brevitas/core/scaling/runtime.py
@@ -30,12 +30,18 @@ def __init__(
             tracked_parameter_list: List[torch.nn.Parameter],
             scaling_shape: Tuple[int, ...],
             restrict_scaling_impl: Module = FloatRestrictValue(),
+            restrict_threshold_impl: Optional[Module] = None,
             affine_rescaling: bool = False,
             affine_shift_scale: bool = False,
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
             device: Optional[torch.device] = None) -> None:
         super(StatsFromParameterScaling, self).__init__()
+
+        # Ensure retro-compatibility with shared threshold/scaling restrict
+        if restrict_threshold_impl is None:
+            restrict_threshold_impl = restrict_scaling_impl
+
         self.parameter_list_stats = _ParameterListStats(
             scaling_stats_impl,
             scaling_shape,
@@ -44,6 +50,7 @@ def __init__(
             tracked_parameter_list)
         self.stats_scaling_impl = _StatsScaling(
             restrict_scaling_impl,
+            restrict_threshold_impl,
             scaling_shape,
             scaling_min_val,
             affine_rescaling,
@@ -65,6 +72,7 @@ class _StatsScaling(brevitas.jit.ScriptModule):
     def __init__(
             self,
             restrict_scaling_impl: Module,
+            restrict_threshold_impl: Module,
             scaling_shape: Tuple[int, ...],
             scaling_min_val: Optional[float],
             affine_rescaling: bool,
@@ -81,16 +89,18 @@ def __init__(
         else:
             self.affine_rescaling = Identity()
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
+        self.restrict_clamp_threshold = _RestrictClampValue(
+            restrict_value_impl=restrict_threshold_impl)
         self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
-        self.restrict_scaling_impl = restrict_scaling_impl
+        self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module()
 
     @brevitas.jit.script_method
     def forward(
             self, stats: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
         if threshold is None:
             threshold = torch.ones(1).type_as(stats)
-        threshold = self.restrict_scaling_pre(threshold)
-        threshold = self.restrict_clamp_scaling(threshold)
+        threshold = self.restrict_threshold_pre(threshold)
+        threshold = self.restrict_clamp_threshold(threshold)
         stats = self.restrict_scaling_pre(stats)
         stats = self.affine_rescaling(stats)
         stats = self.restrict_clamp_scaling(stats)
@@ -108,12 +118,17 @@ def __init__(
             affine_rescaling: bool = False,
             affine_shift_scale: bool = False,
             restrict_scaling_impl: Module = FloatRestrictValue(),
+            restrict_threshold_impl: Optional[Module] = None,
             scaling_stats_momentum: float = DEFAULT_MOMENTUM,
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
             device: Optional[torch.device] = None) -> None:
         super(RuntimeStatsScaling, self).__init__()
+
+        # Ensure retro-compatibility with shared threshold/scaling restrict
+        if restrict_threshold_impl is None:
+            restrict_threshold_impl = restrict_scaling_impl
+
         self.runtime_stats = _RuntimeStats(
             scaling_stats_impl,
             scaling_shape,
@@ -123,6 +138,7 @@ def __init__(
             device)
         self.stats_scaling_impl = _StatsScaling(
             restrict_scaling_impl,
+            restrict_threshold_impl,
             scaling_shape,
             scaling_min_val,
             affine_rescaling,
@@ -174,13 +190,14 @@ def _load_from_state_dict(
 class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule):
 
     def __init__(
-            self,
-            group_size: int,
-            group_dim: int,
-            input_view_impl: Module,
-            scaling_stats_impl: Module,
-            scaling_min_val: Optional[float],
-            restrict_scaling_impl: Module = FloatRestrictValue()) -> None:
+            self,
+            group_size: int,
+            group_dim: int,
+            input_view_impl: Module,
+            scaling_stats_impl: Module,
+            scaling_min_val: Optional[float],
+            restrict_scaling_impl: Module = FloatRestrictValue(),
+            restrict_threshold_impl: Optional[Module] = None) -> None:
         super(RuntimeDynamicGroupStatsScaling, self).__init__()
         self.group_size = group_size
         self.group_dim = group_dim
@@ -188,7 +205,11 @@ def __init__(
         self.scaling_min_val = scaling_min_val
         self.input_view_impl = input_view_impl
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
-        self.restrict_module = self.restrict_clamp_scaling.restrict_value_impl.restrict_init_module(
-        )
+        self.restrict_clamp_threshold = _RestrictClampValue(
+            restrict_value_impl=restrict_threshold_impl)
+        self.restrict_scaling_pre = self.restrict_clamp_scaling.restrict_value_impl.restrict_init_module(
+        )
+        self.restrict_threshold_pre = self.restrict_clamp_threshold.restrict_value_impl.restrict_init_module(
+        )
 
     @brevitas.jit.script_method
     def forward(
             self,
             stats_input: torch.Tensor,
             threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
         if threshold is None:
             threshold = torch.ones(1).type_as(stats_input)
         stats_input_reshaped = self.input_view_impl(stats_input)
-        threshold = self.restrict_clamp_scaling(self.restrict_module(threshold))
+        threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold))
         out = self.scaling_stats_impl(stats_input_reshaped)
         # Apply log scaling
-        out = self.restrict_module(out)
+        out = self.restrict_scaling_pre(out)
         # Scaling min val
         out = self.restrict_clamp_scaling(out) / threshold
         return out
diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py
index e43fd577a..d9347898f 100644
--- a/src/brevitas/core/scaling/standalone.py
+++ b/src/brevitas/core/scaling/standalone.py
@@ -62,20 +62,27 @@ def __init__(
             self,
             scaling_init: Union[float, Tensor],
             restrict_scaling_impl: Module = FloatRestrictValue(),
+            restrict_threshold_impl: Optional[Module] = None,
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
             device: Optional[torch.device] = None) -> None:
         super(ConstScaling, self).__init__()
+
+        # Ensure retro-compatibility with shared threshold/scaling restrict
+        if restrict_threshold_impl is None:
+            restrict_threshold_impl = restrict_scaling_impl
+
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
+        self.restrict_clamp_threshold = _RestrictClampValue(
+            restrict_value_impl=restrict_threshold_impl)
         if isinstance(scaling_init, Tensor):
             scaling_init = scaling_init.to(device=device, dtype=dtype)
             scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init)
-            self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
             self.value = StatelessBuffer(scaling_init.detach())
         else:
             scaling_init = restrict_scaling_impl.restrict_init_float(scaling_init)
-            self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
             self.value = StatelessBuffer(torch.tensor(scaling_init, dtype=dtype, device=device))
+        self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module()
 
     @brevitas.jit.script_method
     def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
@@ -83,7 +90,7 @@ def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
             threshold = torch.ones(1).type_as(placeholder)
         # We first apply any restriction to scaling
         # For IntQuant, this is no-op, retrocompatible.
-        threshold = self.restrict_clamp_scaling(self.restrict_init_module(threshold))
+        threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold))
         restricted_value = self.restrict_clamp_scaling(self.value())
         restricted_value = restricted_value / threshold
         return restricted_value
@@ -133,11 +140,16 @@ def __init__(
             scaling_init: Union[float, Tensor],
             scaling_shape: Optional[Tuple[int, ...]] = None,
             restrict_scaling_impl: Module = FloatRestrictValue(),
+            restrict_threshold_impl: Optional[Module] = None,
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
             device: Optional[torch.device] = None) -> None:
         super(ParameterScaling, self).__init__()
 
+        # Ensure retro-compatibility with shared threshold/scaling restrict
+        if restrict_threshold_impl is None:
+            restrict_threshold_impl = restrict_scaling_impl
+
         if (isinstance(scaling_init, Tensor) and scaling_shape is not None and
                 scaling_init.shape != SCALAR_SHAPE and scaling_init.shape != scaling_shape):
             raise RuntimeError("scaling_init.shape is non-scalar and != from scaling_shape.")
@@ -149,12 +161,14 @@ def __init__(
             scaling_init = torch.tensor(scaling_init, dtype=dtype, device=device)
 
         scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init)
-        self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
 
         if scaling_init.shape == SCALAR_SHAPE and scaling_shape is not None:
             scaling_init = torch.full(scaling_shape, scaling_init, dtype=dtype, device=device)
         self.value = Parameter(scaling_init)
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
+        self.restrict_clamp_threshold = _RestrictClampValue(
+            restrict_value_impl=restrict_threshold_impl)
+        self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module()
 
     @brevitas.jit.script_method
     def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
@@ -162,7 +176,7 @@ def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
             threshold = torch.ones(1).type_as(placeholder)
         # We first apply any restriction to scaling
         # For IntQuant, this is no-op, retrocompatible.
-        threshold = self.restrict_clamp_scaling(self.restrict_init_module(threshold))
+        threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold))
         value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))
         return value / threshold
@@ -193,6 +207,7 @@ def __init__(
             tracked_parameter_list: List[torch.nn.Parameter],
             scaling_shape: Tuple[int, ...],
             restrict_scaling_impl: Module = FloatRestrictValue(),
+            restrict_threshold_impl: Optional[Module] = None,
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
             device: Optional[torch.device] = None) -> None:
@@ -203,13 +218,26 @@ def __init__(
             scaling_stats_input_view_shape_impl,
             scaling_stats_input_concat_dim,
             tracked_parameter_list)
-        self.restrict_scaling_impl = restrict_scaling_impl
+
+        # Ensure retro-compatibility with shared threshold/scaling restrict
+        if restrict_threshold_impl is None:
+            restrict_threshold_impl = restrict_scaling_impl
+
         self.stats_scaling_impl = _StatsScaling(
-            restrict_scaling_impl, scaling_shape, scaling_min_val, False, False, dtype, device)
+            restrict_scaling_impl,
+            restrict_threshold_impl,
+            scaling_shape,
+            scaling_min_val,
+            False,
+            False,
+            dtype,
+            device)
+        self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module()
+        self.restrict_inplace_scaling_pre = restrict_scaling_impl.restrict_init_inplace_module()
+
         self.init_done: bool = brevitas.jit.Attribute(False, bool)
         self.local_loss_mode: bool = brevitas.jit.Attribute(False, bool)
-        self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
-        self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
+
         self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
 
     @brevitas.jit.script_method
@@ -220,8 +248,8 @@ def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
         # This is because we don't want to store a parameter dependant on a runtime value (threshold)
         # And because restrict needs to happen after we divide by threshold
         if self.init_done:
-            threshold = self.stats_scaling_impl.restrict_clamp_scaling(
-                self.restrict_preprocess(threshold))
+            threshold = self.stats_scaling_impl.restrict_clamp_threshold(
+                self.restrict_threshold_pre(threshold))
             value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value))
             value = value / threshold
             return value
@@ -231,9 +259,9 @@ def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
             if self.local_loss_mode:
                 return self.stats_scaling_impl(stats, threshold)
-            stats = self.restrict_inplace_preprocess(stats)
-            threshold = self.stats_scaling_impl.restrict_clamp_scaling(
-                self.restrict_preprocess(threshold))
+            stats = self.restrict_inplace_scaling_pre(stats)
+            threshold = self.stats_scaling_impl.restrict_clamp_threshold(
+                self.restrict_threshold_pre(threshold))
             inplace_tensor_mul(self.value.detach(), stats)
             value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value))
             value = value / threshold
@@ -314,12 +342,18 @@ def __init__(
             scaling_stats_input_view_shape_impl: Module = OverBatchOverTensorView(),
             scaling_shape: Tuple[int, ...] = SCALAR_SHAPE,
             restrict_scaling_impl: Module = FloatRestrictValue(),
+            restrict_threshold_impl: Optional[Module] = None,
             scaling_stats_momentum: Optional[float] = DEFAULT_MOMENTUM,
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
             device: Optional[torch.device] = None) -> None:
         super(ParameterFromRuntimeStatsScaling, self).__init__()
         assert collect_stats_steps > 0, 'Steps should be more than 0'
+
+        # Ensure retro-compatibility with shared threshold/scaling restrict
+        if restrict_threshold_impl is None:
+            restrict_threshold_impl = restrict_scaling_impl
+
         self.collect_stats_steps: int = brevitas.jit.Attribute(collect_stats_steps, int)
         self.counter: int = brevitas.jit.Attribute(0, int)
         self.stats_input_view_shape_impl = scaling_stats_input_view_shape_impl
@@ -328,13 +362,14 @@ def __init__(
             scaling_stats_momentum, Optional[float])
         self.register_buffer('buffer', torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
         self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
-        self.restrict_scaling_impl = restrict_scaling_impl
         self.restrict_scaling = _RestrictValue(restrict_scaling_impl)
+        self.restrict_threshold = _RestrictValue(restrict_threshold_impl)
         self.clamp_scaling = _ClampValue(scaling_min_val)
         self.local_loss_mode: bool = brevitas.jit.Attribute(
             False, bool)  # required to support MSE eval or variants
         self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
-        self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
+        self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
+        self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module()
 
     @brevitas.jit.script_method
     def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor:
@@ -362,13 +397,13 @@ def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor:
         elif self.counter == self.collect_stats_steps:
             self.restrict_inplace_preprocess(self.buffer)
             inplace_tensor_mul(self.value.detach(), self.buffer)
-            threshold = self.restrict_scaling(self.restrict_preprocess(threshold))
+            threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold))
             value = self.clamp_scaling(self.restrict_scaling(self.value))
             value = value / threshold
             self.counter = self.counter + 1
             return abs_binary_sign_grad(value)
         else:
-            threshold = self.restrict_scaling(self.restrict_preprocess(threshold))
+            threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold))
             value = self.clamp_scaling(self.restrict_scaling(self.value))
             value = value / threshold
             return abs_binary_sign_grad(value)
@@ -383,10 +418,10 @@ def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
         else:
             if self.counter <= self.collect_stats_steps:
                 out = self.buffer
-                out = self.restrict_preprocess(out)
+                out = self.restrict_scaling_pre(out)
             else:
                 out = self.value
-            threshold = self.restrict_scaling(self.restrict_preprocess(threshold))
+            threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold))
             out = self.clamp_scaling(self.restrict_scaling(out))
             out = out / threshold
             out = abs_binary_sign_grad(self.clamp_scaling(out))
             return out

From 2e34865fce4d998553a0a6b3fe8ed3880fbeaf1b Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Mon, 28 Oct 2024 13:08:43 +0000
Subject: [PATCH 05/10] Clean-up

---
 src/brevitas/core/restrict_val.py      |  5 ++++-
 .../quant/experimental/mx_quant_ocp.py | 18 ++++++++++++++++++
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/src/brevitas/core/restrict_val.py b/src/brevitas/core/restrict_val.py
index 7eb9845f9..7d6d83231 100644
--- a/src/brevitas/core/restrict_val.py
+++ b/src/brevitas/core/restrict_val.py
@@ -24,7 +24,10 @@
 
 class _RestrictClampValue(brevitas.jit.ScriptModule):
 
-    def __init__(self, scaling_min_val: Optional[float], restrict_value_impl: Optional[Module]):
+    def __init__(
+            self,
+            scaling_min_val: Optional[float] = None,
+            restrict_value_impl: Optional[Module] = None):
         super(_RestrictClampValue, self).__init__()
         if scaling_min_val is not None and scaling_min_val != 0:
             self.clamp_min_ste = ScalarClampMinSte(scaling_min_val)
diff --git a/src/brevitas/quant/experimental/mx_quant_ocp.py b/src/brevitas/quant/experimental/mx_quant_ocp.py
index 551f4f3d7..5900fe663 100644
--- a/src/brevitas/quant/experimental/mx_quant_ocp.py
+++ b/src/brevitas/quant/experimental/mx_quant_ocp.py
@@ -1,10 +1,13 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+from dependencies import this
 from dependencies import value
 
 from brevitas.core.function_wrapper.ops_ste import CeilSte
 from brevitas.core.function_wrapper.ops_ste import FloorSte
+from brevitas.core.restrict_val import PowerOfTwo
+from brevitas.core.restrict_val import PowerOfTwoRestrictValue
 from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling
 from brevitas.inject import ExtendedInjector
 from brevitas.inject.enum import RestrictValueType
@@ -44,14 +47,25 @@ class GroupwiseActProxyMixin(ExtendedInjector):
     proxy_class = GroupwiseActQuantProxyFromInjector
 
 
+class RestrictThresholdMixin(ExtendedInjector):
+    restrict_value_float_to_int_impl = FloorSte
+    restrict_scaling_impl = PowerOfTwoRestrictValue
+
+
 class MXWeightMixin(ExtendedInjector):
+    threshold_mixin = RestrictThresholdMixin
     group_size = 32
     restrict_scaling_type = RestrictValueType.POWER_OF_TWO
     restrict_value_float_to_int_impl = FloorSte
     scaling_per_output_type = ScalingPerOutputType.GROUP
 
+    @value
+    def restrict_threshold_impl():
+        return this.threshold_mixin.restrict_scaling_impl
+
 
 class MXActMixin(ExtendedInjector):
+    threshold_mixin = RestrictThresholdMixin
     group_size = 32
     restrict_scaling_type = RestrictValueType.POWER_OF_TWO
     restrict_value_float_to_int_impl = FloorSte
     scaling_impl = RuntimeDynamicGroupStatsScaling
     scaling_per_output_type = ScalingPerOutputType.GROUP
@@ -66,6 +80,10 @@ def stats_reduce_dim(group_dim):
         else:
             return group_dim + 1
 
+    @value
+    def restrict_threshold_impl():
+        return this.threshold_mixin.restrict_scaling_impl
+
 
 class MXFloat8e4m3Weight(MXWeightMixin,
                          GroupwiseWeightFloatProxyMixin,
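
[Review note, not part of the patch series] PATCH 04/05 give the threshold its own restriction (restrict_threshold_impl) instead of reusing the scale's, falling back to the shared one for retro-compatibility. A stripped-down sketch of that wiring (class and argument names here are simplified stand-ins, not brevitas API):

import torch
from torch import nn

class DecoupledScaling(nn.Module):

    def __init__(self, restrict_scaling, restrict_threshold=None):
        super().__init__()
        # Retro-compatibility: fall back to the scaling restriction when no
        # threshold-specific restriction is provided, as the patch does.
        self.restrict_scaling = restrict_scaling
        self.restrict_threshold = restrict_threshold if restrict_threshold is not None else restrict_scaling

    def forward(self, stats: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
        # Restrict both quantities independently, then divide in linear domain.
        return self.restrict_scaling(stats) / self.restrict_threshold(threshold)

po2_floor = lambda x: torch.exp2(torch.floor(torch.log2(x)))
scaling = DecoupledScaling(restrict_scaling=po2_floor)  # threshold shares po2
print(scaling(torch.tensor(10.0), torch.tensor(3.0)))   # tensor(4.)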
From e07085e4a7c855728ab4f2ec7263d64091593f8b Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Mon, 28 Oct 2024 13:22:39 +0000
Subject: [PATCH 06/10] fix

---
 src/brevitas/core/scaling/standalone.py | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py
index d9347898f..13ead5afc 100644
--- a/src/brevitas/core/scaling/standalone.py
+++ b/src/brevitas/core/scaling/standalone.py
@@ -244,9 +244,6 @@ def __init__(
     def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
         if threshold is None:
             threshold = torch.ones(1).type_as(ignored)
-        # Threshold division must happen after we update self.value, but before we apply restrict_preproces
-        # This is because we don't want to store a parameter dependant on a runtime value (threshold)
-        # And because restrict needs to happen after we divide by threshold
         if self.init_done:
             threshold = self.stats_scaling_impl.restrict_clamp_threshold(
                 self.restrict_threshold_pre(threshold))
@@ -373,9 +370,6 @@ def __init__(
     @brevitas.jit.script_method
     def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor:
-        # Threshold division must happen after we update self.value, but before we apply restrict_preproces
-        # This is because we don't want to store a parameter dependent on a runtime value (threshold)
-        # And because restrict needs to happen after we divide by threshold
         if self.counter < self.collect_stats_steps:
             stats_input = self.stats_input_view_shape_impl(stats_input)
             stats = self.stats(stats_input)
@@ -437,7 +431,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
             del output_dict[prefix + 'value']
         # Save buffer into value for any non-zero number of collection steps
         elif self.counter <= self.collect_stats_steps:
-            output_dict[prefix + 'value'] = self.restrict_preprocess(self.buffer)
+            output_dict[prefix + 'value'] = self.restrict_scaling_pre(self.buffer)
         return output_dict

From 79c5811c704ed942cf2ce509c6862ceef23434a2 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Mon, 28 Oct 2024 20:56:50 +0000
Subject: [PATCH 07/10] New flag

---
 src/brevitas_examples/common/generative/quantize.py | 13 +++++++++++++
 src/brevitas_examples/llm/main.py                   |  7 +++++++
 2 files changed, 20 insertions(+)

diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py
index 9460fadf1..3cbdadcfa 100644
--- a/src/brevitas_examples/common/generative/quantize.py
+++ b/src/brevitas_examples/common/generative/quantize.py
@@ -8,6 +8,9 @@
 from torch import nn
 
 from brevitas import nn as qnn
+from brevitas.core.function_wrapper import CeilSte
+from brevitas.core.function_wrapper import FloorSte
+from brevitas.core.restrict_val import RoundSte
 from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
 from brevitas.graph.quantize import layerwise_quantize
 from brevitas.quant.experimental.float import Fp8e4m3Act
@@ -220,6 +223,7 @@ def generate_quantizers(
         input_quant_granularity=None,
         input_group_size=None,
         quantize_input_zero_point=False,
+        scale_rounding_func_type=None,
         device=None,
         weight_kwargs=None,
         input_kwargs=None):
@@ -278,6 +282,15 @@ def generate_quantizers(
                 'quantize_zero_point': quantize_weight_zero_point},
             **weight_float_format)
 
+    if scale_rounding_func_type is not None:
+        scale_rounding_func_dict = {'ceil': CeilSte, 'floor': FloorSte, 'round': RoundSte}
+        scale_type = scale_rounding_func_dict[scale_rounding_func_type]
+        weight_quant = weight_quant.let(**{'restrict_value_float_to_int_impl': scale_type})
+        input_quant = input_quant.let(**{'restrict_value_float_to_int_impl': scale_type})
+        sym_input_quant = sym_input_quant.let(**{'restrict_value_float_to_int_impl': scale_type})
+        linear_input_quant = linear_input_quant.let(
+            **{'restrict_value_float_to_int_impl': scale_type})
+
     if weight_group_dim is not None:
         weight_quant = weight_quant.let(**{'group_dim': weight_group_dim})
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 4a87f5a1a..4a2df3a66 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -253,6 +253,7 @@ def main(args):
         input_quant_granularity=args.input_quant_granularity,
         input_group_size=args.input_group_size,
         quantize_input_zero_point=args.quantize_input_zero_point,
+        scale_rounding_func_type=args.scale_rounding_func_type,
         device=device)
     layer_map = generate_quant_maps(
         linear_input_quant=linear_input_quant,
@@ -400,6 +401,12 @@ def parse_args(args):
         default='per_group',
         choices=['per_channel', 'per_tensor', 'per_group'],
         help='Granularity for scales/zero-point of weights. Default: per_group.')
+    parser.add_argument(
+        '--scale-rounding-func-type',
+        type=str,
+        default=None,
+        choices=['round', 'ceil', 'floor'],
+        help='Rounding function to use with Po2 scale. Default: None.')
     parser.add_argument(
         '--weight-group-dim',
         type=int,

From 726356ae23b71f2eb35b1cb0bd5e07ca98100b5d Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 29 Oct 2024 08:52:09 +0000
Subject: [PATCH 08/10] fix for input_quant

---
 src/brevitas_examples/common/generative/quantize.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py
index 3cbdadcfa..eaabf4d81 100644
--- a/src/brevitas_examples/common/generative/quantize.py
+++ b/src/brevitas_examples/common/generative/quantize.py
@@ -286,10 +286,14 @@ def generate_quantizers(
         scale_rounding_func_dict = {'ceil': CeilSte, 'floor': FloorSte, 'round': RoundSte}
         scale_type = scale_rounding_func_dict[scale_rounding_func_type]
         weight_quant = weight_quant.let(**{'restrict_value_float_to_int_impl': scale_type})
-        input_quant = input_quant.let(**{'restrict_value_float_to_int_impl': scale_type})
-        sym_input_quant = sym_input_quant.let(**{'restrict_value_float_to_int_impl': scale_type})
-        linear_input_quant = linear_input_quant.let(
-            **{'restrict_value_float_to_int_impl': scale_type})
+        if input_quant is not None:
+            input_quant = input_quant.let(**{'restrict_value_float_to_int_impl': scale_type})
+        if sym_input_quant is not None:
+            sym_input_quant = sym_input_quant.let(
+                **{'restrict_value_float_to_int_impl': scale_type})
+        if linear_input_quant is not None:
+            linear_input_quant = linear_input_quant.let(
+                **{'restrict_value_float_to_int_impl': scale_type})

From 3ae82a4b1fb6710fdcd9394c113d1195c5c558d6 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 29 Oct 2024 09:10:49 +0000
Subject: [PATCH 09/10] Add default

---
 src/brevitas/core/scaling/runtime.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py
index 9792ebdae..09f891ed7 100644
--- a/src/brevitas/core/scaling/runtime.py
+++ b/src/brevitas/core/scaling/runtime.py
@@ -199,6 +199,11 @@ def __init__(
             restrict_scaling_impl: Module = FloatRestrictValue(),
             restrict_threshold_impl: Optional[Module] = None) -> None:
         super(RuntimeDynamicGroupStatsScaling, self).__init__()
+
+        # Ensure retro-compatibility with shared threshold/scaling restrict
+        if restrict_threshold_impl is None:
+            restrict_threshold_impl = restrict_scaling_impl
+
         self.group_size = group_size
         self.group_dim = group_dim
         self.scaling_stats_impl = scaling_stats_impl

From bb36a270dc589fbe9b84a852f79da239550a99eb Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Wed, 30 Oct 2024 13:25:31 +0100
Subject: [PATCH 10/10] Empty commit for tests
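
[Review note, not part of the patch series] PATCH 07 exposes the scale-rounding choice as a --scale-rounding-func-type flag, and PATCH 08 guards the optional input quantizers. The dictionary dispatch reduces to this pattern (the quantizer list below is a hypothetical stand-in; the imports and the .let() override are the ones the patches themselves use):

from brevitas.core.function_wrapper import CeilSte, FloorSte
from brevitas.core.restrict_val import RoundSte

SCALE_ROUNDING_FUNCS = {'ceil': CeilSte, 'floor': FloorSte, 'round': RoundSte}

def apply_scale_rounding(quantizers, scale_rounding_func_type=None):
    # No flag given: leave every quantizer untouched (the default behaviour).
    if scale_rounding_func_type is None:
        return quantizers
    scale_type = SCALE_ROUNDING_FUNCS[scale_rounding_func_type]
    # PATCH 08: any of the input quantizers may be disabled (None), so guard
    # each one before overriding its float-to-int rounding implementation.
    return [
        q.let(**{'restrict_value_float_to_int_impl': scale_type}) if q is not None else None
        for q in quantizers]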