diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index f4fd79f1a..65f56a134 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -64,11 +64,10 @@ def __init__( if dtype is None: dtype = torch.get_default_dtype() self.eps = torch.finfo(dtype).tiny + self.observer_only = brevitas.jit.Attribute(False, bool) @brevitas.jit.script_method - def quantize(self, x: torch.Tensor): - scale = self.scaling_impl(x) - + def quantize(self, x: torch.Tensor, scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if self.float_scaling_impl is not None: float_scaling_impl_value = self.float_scaling_impl( self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()) @@ -86,10 +85,15 @@ def dequantize(self, y, scale): @brevitas.jit.script_method def forward(self, x): - y, scale = self.quantize(x) - # after quantizing, clamp to special cases like NaN/inf if they are set - y, saturating, inf_values, nan_values = self.float_clamp_impl( - y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()) - y = self.dequantize(y, scale) + scale = self.scaling_impl(x) + if self.observer_only: + y = x + saturating, inf_values, nan_values = self.float_clamp_impl.saturating, self.float_clamp_impl.inf_values, self.float_clamp_impl.nan_values + else: + y, scale = self.quantize(x, scale) + # after quantizing, clamp to special cases like NaN/inf if they are set + y, saturating, inf_values, nan_values = self.float_clamp_impl( + y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()) + y = self.dequantize(y, scale) # This is to respect the current interface of proxies return y, scale, self.zero_point_impl(), self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias(), saturating, inf_values, nan_values diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index cdb75df74..e1cc271d8 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -145,6 +145,7 @@ def __init__( self.int_scaling_impl = int_scaling_impl self.zero_point_impl = zero_point_impl self.msb_clamp_bit_width_impl = bit_width_impl + self.observer_only = brevitas.jit.Attribute(False, bool) @brevitas.jit.script_method def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: @@ -153,7 +154,10 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: int_threshold = self.int_scaling_impl(bit_width) scale = threshold / int_threshold zero_point = self.zero_point_impl(x, scale, bit_width) - y = self.int_quant(scale, zero_point, bit_width, x) + if self.observer_only: + y = x + else: + y = self.int_quant(scale, zero_point, bit_width, x) return y, scale, zero_point, bit_width @@ -176,6 +180,7 @@ def __init__( self.pre_zero_point_impl = pre_zero_point_impl self.zero_point_impl = zero_point_impl self.msb_clamp_bit_width_impl = bit_width_impl + self.observer_only = brevitas.jit.Attribute(False, bool) @brevitas.jit.script_method def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: @@ -187,7 +192,10 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Te threshold = self.scaling_impl(x) scale = threshold / int_threshold zero_point = self.zero_point_impl(x, scale, bit_width) - y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x) + if self.observer_only: + y = x + else: + y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x) return y, scale, zero_point, bit_width, pre_scale, pre_zero_point @@ -253,5 +261,8 @@ def forward(self, x: Tensor, input_bit_width: Tensor, threshold = self.scaling_impl(x) scale = threshold / int_threshold zero_point = self.zero_point_impl(x, scale, bit_width) - y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x) + if self.observer_only: + y = x + else: + y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x) return y, scale, zero_point, bit_width, pre_scale, pre_zero_point diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py index 29d4d06e8..ac520a707 100644 --- a/src/brevitas/core/stats/stats_op.py +++ b/src/brevitas/core/stats/stats_op.py @@ -442,6 +442,19 @@ def _set_local_loss_mode(module, enabled): m.local_loss_mode = enabled +def _set_observer_mode(module, enabled, previous_observer_mode): + for m in module.modules(): + if hasattr(m, 'observer_only'): + previous_observer_mode[m] = m.observer_only + m.observer_only = enabled + + +def _restore_observer_mode(module, previous_observer_mode): + for m in module.modules(): + if hasattr(m, 'observer_only'): + m.observer_only = previous_observer_mode[m] + + class MSE(torch.nn.Module): # References: # https://github.com/cornell-zhang/dnn-quant-ocs/blob/master/distiller/quantization/clip.py @@ -459,7 +472,12 @@ def __init__( self.mse_init_op = mse_init_op self.input_view_shape_impl = inner_stats_input_view_shape_impl self.proxy_forward = proxy_module.forward + self.previous_observer_mode = dict() self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled) + self.set_observer_mode = lambda enabled: _set_observer_mode( + proxy_module, enabled, self.previous_observer_mode) + self.restore_observer_mode = lambda: _restore_observer_mode( + proxy_module, self.previous_observer_mode) self.internal_candidate = None self.num = mse_iters self.search_method = mse_search_method @@ -480,10 +498,12 @@ def evaluate_loss(self, x, candidate): self.internal_candidate = candidate # Set to local_loss_mode before calling the proxy self.set_local_loss_mode(True) + self.set_observer_mode(False) quant_value = self.proxy_forward(x) quant_value = _unpack_quant_tensor(quant_value) loss = self.mse_loss_fn(x, quant_value) self.set_local_loss_mode(False) + self.restore_observer_mode() return loss def mse_grid_search(self, xl, x): @@ -571,7 +591,12 @@ def __init__( self.hqo_init_op = hqo_init_op_scale self.input_view_shape_impl = inner_stats_input_view_shape_impl self.proxy_forward = proxy_module.forward + self.previous_observer_mode = dict() self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled) + self.set_observer_mode = lambda enabled: _set_observer_mode( + proxy_module, enabled, self.previous_observer_mode) + self.restore_observer_mode = lambda: _restore_observer_mode( + proxy_module, self.previous_observer_mode) self.internal_candidate = None self.hqo_iters = hqo_iters_scale self.stats_reduce_dim = stats_reduce_dim @@ -598,8 +623,10 @@ def parameter_search(self, xl, x): for i in range(0, self.hqo_iters): self.internal_candidate = candidate self.set_local_loss_mode(True) + self.set_observer_mode(False) quant_tensor = self.proxy_forward(x).detach() self.set_local_loss_mode(False) + self.restore_observer_mode() loss = torch.abs(quant_tensor.value - x).mean() best_candidate = torch.where(loss < best_loss, candidate, best_candidate) @@ -670,7 +697,12 @@ def __init__( self.hqo_init_op_zp = hqo_init_op_zp self.input_view_shape_impl = inner_stats_input_view_shape_impl self.proxy_forward = proxy_module.forward + self.previous_observer_mode = dict() self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled) + self.set_observer_mode = lambda enabled: _set_observer_mode( + proxy_module, enabled, self.previous_observer_mode) + self.restore_observer_mode = lambda: _restore_observer_mode( + proxy_module, self.previous_observer_mode) self.internal_candidate = None self.stats_reduce_dim = stats_reduce_dim self.local_loss_mode: bool = False @@ -688,8 +720,10 @@ def parameter_search(self, xl, x): for i in range(0, self.hqo_iters): self.internal_candidate = candidate self.set_local_loss_mode(True) + self.set_observer_mode(False) quant_tensor = self.proxy_forward(x).detach() self.set_local_loss_mode(False) + self.restore_observer_mode() qt_value = self.input_view_shape_impl(quant_tensor.value) qt_scale = self.input_view_shape_impl(quant_tensor.scale) qt_zp = self.input_view_shape_impl(quant_tensor.zero_point) diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py index c0fc9efdb..9c753952e 100644 --- a/src/brevitas/graph/calibrate.py +++ b/src/brevitas/graph/calibrate.py @@ -200,8 +200,9 @@ def disable_act_quantization(self, model, is_training): if isinstance(module, ActQuantProxyFromInjectorBase): module.train(is_training) if self.call_act_quantizer_impl: - hook = module.register_forward_hook(self.disable_act_quant_hook) - self.disable_act_quant_hooks.append(hook) + for m in module.modules(): + if hasattr(m, 'observer_only'): + m.observer_only = True else: module.disable_quant = True elif isinstance(module, _ACC_PROXIES): @@ -228,9 +229,9 @@ def enable_act_quantization(self, model, is_training): elif isinstance(module, ActQuantProxyFromInjectorBase): module.disable_quant = False module.train(is_training) - for hook in self.disable_act_quant_hooks: - hook.remove() - self.disable_act_quant_hooks = [] + for m in module.modules(): + if hasattr(m, 'observer_only'): + m.observer_only = False def enable_param_quantization(self, model, is_training): for module in model.modules(): diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index 16b8a4b5f..552472717 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -98,8 +98,8 @@ def test_float_to_quant_float(inp, minifloat_format): signed=signed, float_clamp_impl=float_clamp) expected_out, *_ = float_quant(inp) - - out_quant, scale = float_quant.quantize(inp) + scale = float_quant.scaling_impl(inp) + out_quant, scale = float_quant.quantize(inp, scale) exponent_bit_width, mantissa_bit_width, exponent_bias = torch.tensor(exponent_bit_width, dtype=torch.float), torch.tensor(mantissa_bit_width, dtype=torch.float), torch.tensor(exponent_bias, dtype=torch.float) out_quant, *_ = float_quant.float_clamp_impl( out_quant, exponent_bit_width, mantissa_bit_width, exponent_bias) @@ -142,7 +142,8 @@ def test_scaling_impls_called_once(inp, minifloat_format): scaling_impl=scaling_impl, float_scaling_impl=float_scaling_impl, float_clamp_impl=float_clamp) - _ = float_quant.quantize(inp) + scale = float_quant.scaling_impl(inp) + _ = float_quant.quantize(inp, scale) # scaling implementations should be called exaclty once on the input float_scaling_impl.assert_called_once_with( torch.tensor(exponent_bit_width),