diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index f4fd79f1a..c00a8726a 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -67,12 +67,13 @@ def __init__( @brevitas.jit.script_method def quantize(self, x: torch.Tensor): - scale = self.scaling_impl(x) if self.float_scaling_impl is not None: float_scaling_impl_value = self.float_scaling_impl( self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()) - scale = scale / float_scaling_impl_value + else: + float_scaling_impl_value = torch.tensor(1.).type_as(x) + scale = self.scaling_impl(x, float_scaling_impl_value) x = self.input_view_impl(x) scaled_x = x / scale internal_scale = float_internal_scale( diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index cdb75df74..03307f829 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -149,9 +149,8 @@ def __init__( @brevitas.jit.script_method def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: bit_width = self.msb_clamp_bit_width_impl() - threshold = self.scaling_impl(x) int_threshold = self.int_scaling_impl(bit_width) - scale = threshold / int_threshold + scale = self.scaling_impl(x, int_threshold) zero_point = self.zero_point_impl(x, scale, bit_width) y = self.int_quant(scale, zero_point, bit_width, x) return y, scale, zero_point, bit_width diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py index e4333186d..310f65717 100644 --- a/src/brevitas/core/scaling/runtime.py +++ b/src/brevitas/core/scaling/runtime.py @@ -51,9 +51,9 @@ def __init__( device) @brevitas.jit.script_method - def forward(self, ignored: torch.Tensor) -> torch.Tensor: + def forward(self, ignored: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor: stats = self.parameter_list_stats() - return self.stats_scaling_impl(stats) + return self.stats_scaling_impl(stats, threshold) class _StatsScaling(brevitas.jit.ScriptModule): @@ -80,8 +80,8 @@ def __init__( self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module() @brevitas.jit.script_method - def forward(self, stats: torch.Tensor) -> torch.Tensor: - stats = self.restrict_scaling_pre(stats) + def forward(self, stats: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor: + stats = self.restrict_scaling_pre(stats / threshold) stats = self.affine_rescaling(stats) stats = self.restrict_clamp_scaling(stats) return stats @@ -120,9 +120,9 @@ def __init__( device) @brevitas.jit.script_method - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor: stats = self.runtime_stats(x) - return self.stats_scaling_impl(stats) + return self.stats_scaling_impl(stats, threshold) class _AffineRescaling(brevitas.jit.ScriptModule): @@ -179,9 +179,9 @@ def __init__( self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) @brevitas.jit.script_method - def forward(self, stats_input) -> torch.Tensor: + def forward(self, stats_input: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor: stats_input_reshaped = self.input_view_impl(stats_input) - out = self.scaling_stats_impl(stats_input_reshaped) + out = self.scaling_stats_impl(stats_input_reshaped) / threshold # Scaling min val out = self.restrict_clamp_scaling(out) return out diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index 53f389331..c8e612909 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -77,8 +77,8 @@ def __init__( self.value = StatelessBuffer(torch.tensor(scaling_init, dtype=dtype, device=device)) @brevitas.jit.script_method - def forward(self, placeholder: Tensor) -> Tensor: - value = self.value() + def forward(self, placeholder: Tensor, threshold: torch.Tensor) -> Tensor: + value = self.value() / threshold restricted_value = self.restrict_clamp_scaling(value) return restricted_value @@ -149,8 +149,8 @@ def __init__( self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) @brevitas.jit.script_method - def forward(self, placeholder: Tensor) -> Tensor: - value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value)) + def forward(self, placeholder: Tensor, threshold: torch.Tensor) -> Tensor: + value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value / threshold)) return value def _load_from_state_dict( @@ -201,19 +201,21 @@ def __init__( self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device)) @brevitas.jit.script_method - def forward(self, ignored: torch.Tensor) -> torch.Tensor: + def forward(self, ignored: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor: if self.init_done: - value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value)) + value = abs_binary_sign_grad( + self.stats_scaling_impl.restrict_clamp_scaling(self.value / threshold)) return value else: stats = self.parameter_list_stats() # workaround to avoid find_ununsed_parameter=True in DDP stats = stats + 0. * self.value if self.local_loss_mode: - return self.stats_scaling_impl(stats) + return self.stats_scaling_impl(stats, threshold) stats = self.restrict_inplace_preprocess(stats) inplace_tensor_mul(self.value.detach(), stats) - value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value)) + value = abs_binary_sign_grad( + self.stats_scaling_impl.restrict_clamp_scaling(self.value / threshold)) self.init_done = True return value @@ -317,7 +319,7 @@ def __init__( self.restrict_preprocess = Identity() @brevitas.jit.script_method - def training_forward(self, stats_input: Tensor) -> Tensor: + def training_forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tensor: if self.counter < self.collect_stats_steps: stats_input = self.stats_input_view_shape_impl(stats_input) stats = self.stats(stats_input) @@ -334,25 +336,27 @@ def training_forward(self, stats_input: Tensor) -> Tensor: inplace_momentum_update( self.buffer, clamped_stats.detach(), self.momentum, self.counter, new_counter) self.counter = new_counter - return abs_binary_sign_grad(clamped_stats) + return abs_binary_sign_grad(clamped_stats) / threshold elif self.counter == self.collect_stats_steps: self.restrict_inplace_preprocess(self.buffer) inplace_tensor_mul(self.value.detach(), self.buffer) self.counter = self.counter + 1 - return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(self.value))) + return abs_binary_sign_grad( + self.clamp_scaling(self.restrict_scaling(self.value / threshold))) else: - return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(self.value))) + return abs_binary_sign_grad( + self.clamp_scaling(self.restrict_scaling(self.value / threshold))) @brevitas.jit.script_method - def forward(self, stats_input: Tensor) -> Tensor: + def forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tensor: if self.training: - return self.training_forward(stats_input) + return self.training_forward(stats_input, threshold) else: if self.counter <= self.collect_stats_steps: - out = self.buffer + out = self.buffer / threshold out = self.restrict_preprocess(out) else: - out = self.value + out = self.value / threshold out = abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(out))) return out