From 630f32eb9c28696fcf5d3bf0917cd325de5c0528 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 1 Oct 2024 00:17:54 +0100
Subject: [PATCH] Review

---
 src/brevitas/core/quant/binary.py        |  4 ++--
 src/brevitas/core/quant/float.py         |  2 +-
 src/brevitas/core/quant/ternary.py       |  2 +-
 src/brevitas/core/scaling/pre_scaling.py |  4 ++--
 src/brevitas/core/scaling/runtime.py     | 19 +++++++++++----
 src/brevitas/core/scaling/standalone.py  | 47 +++++++++++++++++++++++-------
 6 files changed, 59 insertions(+), 19 deletions(-)

diff --git a/src/brevitas/core/quant/binary.py b/src/brevitas/core/quant/binary.py
index 2be1d23c0..3a4b7346e 100644
--- a/src/brevitas/core/quant/binary.py
+++ b/src/brevitas/core/quant/binary.py
@@ -58,7 +58,7 @@ def __init__(self, scaling_impl: Module, signed: bool = True, quant_delay_steps:
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
-        scale = self.scaling_impl(x, torch.tensor(1.).type_as(x))
+        scale = self.scaling_impl(x)
         y = binary_sign_ste(x) * scale
         y = self.delay_wrapper(x, y)
         return y, scale, self.zero_point(), self.bit_width()
@@ -118,7 +118,7 @@ def __init__(
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
-        scale = self.scaling_impl(x, torch.tensor(1.).type_as(x))
+        scale = self.scaling_impl(x)
         y = self.tensor_clamp_impl(x, -scale, scale)
         y = binary_sign_ste(y) * scale
         y = self.delay_wrapper(x, y)
diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py
index c00a8726a..145f5ca06 100644
--- a/src/brevitas/core/quant/float.py
+++ b/src/brevitas/core/quant/float.py
@@ -72,7 +72,7 @@ def quantize(self, x: torch.Tensor):
             float_scaling_impl_value = self.float_scaling_impl(
                 self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
         else:
-            float_scaling_impl_value = torch.tensor(1.).type_as(x)
+            float_scaling_impl_value = None
         scale = self.scaling_impl(x, float_scaling_impl_value)
         x = self.input_view_impl(x)
         scaled_x = x / scale
diff --git a/src/brevitas/core/quant/ternary.py b/src/brevitas/core/quant/ternary.py
index 552468477..ffaa873de 100644
--- a/src/brevitas/core/quant/ternary.py
+++ b/src/brevitas/core/quant/ternary.py
@@ -61,7 +61,7 @@ def __init__(self, scaling_impl: Module, threshold: float, quant_delay_steps: in
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
-        scale = self.scaling_impl(x, torch.tensor(1.).type_as(x))
+        scale = self.scaling_impl(x)
         mask = x.abs().gt(self.threshold * scale)
         y = mask.float() * ternary_sign_ste(x)
         y = y * scale
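
Note: with the hunks above, call sites stop passing a dummy `torch.tensor(1.)`
threshold; each scaling implementation now defaults the threshold internally. A
minimal sketch of the new calling convention (a standalone illustration with
made-up names, not Brevitas code; TorchScript requires the Optional to be refined
to a Tensor before use):

    import torch
    from typing import Optional

    class ScaleWithOptionalThreshold(torch.nn.Module):

        def forward(self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
            # Refine Optional[Tensor] to Tensor, mirroring the pattern in this patch
            if threshold is None:
                threshold = torch.ones(1).type_as(x)
            return x.abs().max() / threshold

    m = torch.jit.script(ScaleWithOptionalThreshold())
    print(m(torch.tensor([-4.0, 2.0])))                       # tensor([4.])
    print(m(torch.tensor([-4.0, 2.0]), torch.tensor([2.0])))  # tensor([2.])
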
diff --git a/src/brevitas/core/scaling/pre_scaling.py b/src/brevitas/core/scaling/pre_scaling.py
index 82d12b298..d73c86461 100644
--- a/src/brevitas/core/scaling/pre_scaling.py
+++ b/src/brevitas/core/scaling/pre_scaling.py
@@ -97,7 +97,7 @@ def forward(self, weights: Tensor) -> Tensor:
         weights = self.stats_input_view_shape_impl(weights)
         d_w = self.stats(weights)  # denominator for weight normalization
         g = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))  # g
-        s = self.scaling_impl(weights, torch.tensor(1.).type_as(weights))  # s
+        s = self.scaling_impl(weights)  # s
         value = (s * d_w) / g
         return value
 
@@ -184,7 +184,7 @@ def calc_max_l1_norm(self, input_bit_width: Tensor, input_is_signed: bool) -> Te
     def inner_forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool):
         weights = self.stats_input_view_shape_impl(weights)
         d_w = self.stats(weights)  # denominator for weight normalization
-        s = self.scaling_impl(weights, torch.tensor(1.).type_as(weights))  # s
+        s = self.scaling_impl(weights)  # s
         g = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))  # g
         T = self.calc_max_l1_norm(input_bit_width, input_is_signed)  # T / s
         g = torch.clamp_max(g / s, T)
diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py
index 310f65717..ca2baea17 100644
--- a/src/brevitas/core/scaling/runtime.py
+++ b/src/brevitas/core/scaling/runtime.py
@@ -51,7 +51,10 @@ def __init__(
             device)
 
     @brevitas.jit.script_method
-    def forward(self, ignored: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
+    def forward(
+            self, ignored: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(ignored)
         stats = self.parameter_list_stats()
         return self.stats_scaling_impl(stats, threshold)
 
@@ -80,7 +83,10 @@ def __init__(
         self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
 
     @brevitas.jit.script_method
-    def forward(self, stats: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
+    def forward(
+            self, stats: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats)
         stats = self.restrict_scaling_pre(stats / threshold)
         stats = self.affine_rescaling(stats)
         stats = self.restrict_clamp_scaling(stats)
@@ -120,7 +126,7 @@ def __init__(
             device)
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
         stats = self.runtime_stats(x)
         return self.stats_scaling_impl(stats, threshold)
 
@@ -179,7 +185,12 @@ def __init__(
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
 
     @brevitas.jit.script_method
-    def forward(self, stats_input: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
+    def forward(
+            self,
+            stats_input: torch.Tensor,
+            threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats_input)
         stats_input_reshaped = self.input_view_impl(stats_input)
         out = self.scaling_stats_impl(stats_input_reshaped) / threshold
         # Scaling min val
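
For reference, in the runtime.py hunks above the threshold division happens in the
linear domain, before `restrict_scaling_pre`. With a power-of-two restriction this
is equivalent to a subtraction in the log domain, which is why the order matters.
A small numeric check (plain PyTorch, assumed values):

    import torch

    stats = torch.tensor([8.0])
    threshold = torch.tensor([2.0])

    # Dividing first and then restricting (Log2)...
    a = torch.log2(stats / threshold)
    # ...matches restricting first and subtracting log2(threshold).
    b = torch.log2(stats) - torch.log2(threshold)
    assert torch.allclose(a, b)  # both equal tensor([2.])
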
diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py
index 29da7bbff..7ef576d65 100644
--- a/src/brevitas/core/scaling/standalone.py
+++ b/src/brevitas/core/scaling/standalone.py
@@ -77,7 +77,9 @@ def __init__(
         self.value = StatelessBuffer(torch.tensor(scaling_init, dtype=dtype, device=device))
 
     @brevitas.jit.script_method
-    def forward(self, placeholder: Tensor, threshold: torch.Tensor) -> Tensor:
+    def forward(self, placeholder: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(placeholder)
         value = self.value() / threshold
         restricted_value = self.restrict_clamp_scaling(value)
         return restricted_value
@@ -149,7 +151,9 @@ def __init__(
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
 
     @brevitas.jit.script_method
-    def forward(self, placeholder: Tensor, threshold: torch.Tensor) -> Tensor:
+    def forward(self, placeholder: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(placeholder)
         value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value) / threshold)
         return value
 
@@ -197,14 +201,23 @@ def __init__(
         self.local_loss_mode: bool = brevitas.jit.Attribute(False, bool)
         if restrict_scaling_impl is not None:
             self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
+            self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
         else:
             self.restrict_inplace_preprocess = Identity()
+            self.restrict_preprocess = Identity()
+
         self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
 
     @brevitas.jit.script_method
-    def forward(self, ignored: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
+    def forward(
+            self, ignored: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(ignored)
+        # Threshold division must happen after we update self.value, but before we apply
+        # restrict_preprocess: we don't want to store a parameter dependent on a runtime
+        # value (threshold), and restrict needs to happen after we divide by threshold.
         if self.init_done:
-            value = self.restrict_inplace_preprocess(self.value / threshold)
+            value = self.restrict_preprocess(self.value / threshold)
             value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value))
             return value
         else:
@@ -214,7 +227,7 @@ def forward(self, ignored: torch.Tensor, threshold: torch.Tenso
             if self.local_loss_mode:
                 return self.stats_scaling_impl(stats, threshold)
             inplace_tensor_mul(self.value.detach(), stats)
-            value = self.restrict_inplace_preprocess(self.value / threshold)
-            value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value))
+            value = self.restrict_preprocess(self.value / threshold)
+            value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value))
             self.init_done = True
             return value
@@ -231,6 +244,10 @@ def _load_from_state_dict(
             self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
             error_msgs):
         value_key = prefix + 'value'
+
+        # Before, the parameter would be stored after restrict_preprocess (e.g., Log2).
+        # When we load, if retrocompatibility is enabled, we perform the opposite operation
+        # (e.g., Po2), knowing that the forward pass re-applies restrict_preprocess (Log2).
         if config._RETROCOMPATIBLE_SCALING:
             if not isinstance(self.restrict_scaling_impl, Identity):
                 state_dict[value_key] = self.restrict_scaling_impl.retrocompatibility_op(
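
To make the retrocompatibility comment above concrete: older checkpoints stored
`value` already in the restricted domain (e.g., after Log2), while the new code
keeps the parameter in the linear domain and re-applies the restriction on every
forward, so loading inverts the restriction once (e.g., Po2). A sketch with
assumed values, not the Brevitas API:

    import torch

    stored_log2 = torch.tensor([-3.0])                  # old checkpoint: Log2-domain value
    linear = torch.pow(torch.tensor(2.0), stored_log2)  # retrocompatibility op: Po2
    # The forward pass re-applies Log2, recovering the value the old code used.
    assert torch.allclose(torch.log2(linear), stored_log2)
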
@@ -325,7 +342,13 @@ def __init__(
         self.restrict_preprocess = Identity()
 
     @brevitas.jit.script_method
-    def training_forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tensor:
+    def training_forward(
+            self, stats_input: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats_input)
+        # Threshold division must happen after we update self.value, but before we apply
+        # restrict_preprocess: we don't want to store a parameter dependent on a runtime
+        # value (threshold), and restrict needs to happen after we divide by threshold.
         if self.counter < self.collect_stats_steps:
             stats_input = self.stats_input_view_shape_impl(stats_input)
             stats = self.stats(stats_input)
@@ -335,6 +358,7 @@ def training_forward(self, stats_input: Tensor, threshold: torch.Tens
             new_counter = self.counter + 1
             # Whenever we are in local loss mode, we don't update the counter nor the buffer
             if self.local_loss_mode:
+                # In local loss mode, we exit early and divide by threshold
                 return abs_binary_sign_grad(clamped_stats / threshold)
             if self.counter == 0:
                 inplace_tensor_mul(self.buffer, clamped_stats.detach())
@@ -346,7 +370,6 @@ def training_forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tens
         elif self.counter == self.collect_stats_steps:
             inplace_tensor_mul(self.value.detach(), self.buffer)
             value = self.restrict_preprocess(self.value / threshold)
-            # self.restrict_inplace_preprocess(self.value / threshold)
             self.counter = self.counter + 1
             return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
         else:
@@ -354,8 +377,11 @@ def training_forward(self, stats_input: Tensor, threshold: torch.Tens
             return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
 
     @brevitas.jit.script_method
-    def forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tensor:
+    def forward(self, stats_input: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats_input)
         if self.training:
+            # Threshold division is handled inside training_forward
             return self.training_forward(stats_input, threshold)
         else:
             if self.counter <= self.collect_stats_steps:
@@ -388,6 +414,9 @@ def _load_from_state_dict(
         if retrocomp_value_key in state_dict:
             state_dict[value_key] = state_dict.pop(retrocomp_value_key)
 
+        # Before, the parameter would be stored after restrict_preprocess (e.g., Log2).
+        # When we load, if retrocompatibility is enabled, we perform the opposite operation
+        # (e.g., Po2), knowing that the forward pass re-applies restrict_preprocess (Log2).
         if config._RETROCOMPATIBLE_SCALING:
             if not isinstance(self.restrict_scaling_impl, Identity):
                 state_dict[value_key] = self.restrict_scaling_impl.retrocompatibility_op(
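
The ordering constraint repeated in the comments (update `self.value`, divide by
`threshold`, then apply `restrict_preprocess`, out of place) matters because the
stored parameter must stay threshold-free and the restriction is non-linear. A
minimal numeric sketch with assumed values, using Log2 as the restriction:

    import torch

    value = torch.tensor([8.0])      # stored parameter: linear domain, threshold-free
    threshold = torch.tensor([2.0])  # runtime value, never written back to the parameter

    right = torch.log2(value / threshold)  # divide first, restrict second: log2(4) = 2
    wrong = torch.log2(value) / threshold  # restrict first: log2(8) / 2 = 1.5
    print(right, wrong)  # tensor([2.]) tensor([1.5000])

Recomputing the restriction out of place on each forward is consistent with the
swap from `restrict_inplace_preprocess` to `restrict_preprocess` in the hunks above.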