diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py
index f4fd79f1a..20a513907 100644
--- a/src/brevitas/core/quant/float.py
+++ b/src/brevitas/core/quant/float.py
@@ -64,11 +64,10 @@ def __init__(
         if dtype is None:
             dtype = torch.get_default_dtype()
         self.eps = torch.finfo(dtype).tiny
+        self.observer_only = brevitas.jit.Attribute(False, bool)
 
     @brevitas.jit.script_method
-    def quantize(self, x: torch.Tensor):
-        scale = self.scaling_impl(x)
-
+    def quantize(self, x: torch.Tensor, scale: torch.Tensor) -> Tuple[torch.Tensor]:
         if self.float_scaling_impl is not None:
             float_scaling_impl_value = self.float_scaling_impl(
                 self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
@@ -86,10 +85,15 @@ def dequantize(self, y, scale):
 
     @brevitas.jit.script_method
     def forward(self, x):
-        y, scale = self.quantize(x)
-        # after quantizing, clamp to special cases like NaN/inf if they are set
-        y, saturating, inf_values, nan_values = self.float_clamp_impl(
-            y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
+        scale = self.scaling_impl(x)
+        if self.observer_only:
+            y = x
+            saturating, inf_values, nan_values = self.float_clamp_impl.saturating, self.float_clamp_impl.inf_values, self.float_clamp_impl.nan_values
+        else:
+            y, scale = self.quantize(x, scale)
+            # after quantizing, clamp to special cases like NaN/inf if they are set
+            y, saturating, inf_values, nan_values = self.float_clamp_impl(
+                y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
         y = self.dequantize(y, scale)
         # This is to respect the current interface of proxies
         return y, scale, self.zero_point_impl(), self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias(), saturating, inf_values, nan_values
diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py
index cdb75df74..e1cc271d8 100644
--- a/src/brevitas/core/quant/int.py
+++ b/src/brevitas/core/quant/int.py
@@ -145,6 +145,7 @@ def __init__(
         self.int_scaling_impl = int_scaling_impl
         self.zero_point_impl = zero_point_impl
         self.msb_clamp_bit_width_impl = bit_width_impl
+        self.observer_only = brevitas.jit.Attribute(False, bool)
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
@@ -153,7 +154,10 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
         int_threshold = self.int_scaling_impl(bit_width)
         scale = threshold / int_threshold
         zero_point = self.zero_point_impl(x, scale, bit_width)
-        y = self.int_quant(scale, zero_point, bit_width, x)
+        if self.observer_only:
+            y = x
+        else:
+            y = self.int_quant(scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width
 
 
@@ -176,6 +180,7 @@ def __init__(
         self.pre_zero_point_impl = pre_zero_point_impl
         self.zero_point_impl = zero_point_impl
         self.msb_clamp_bit_width_impl = bit_width_impl
+        self.observer_only = brevitas.jit.Attribute(False, bool)
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
@@ -187,7 +192,10 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Te
         threshold = self.scaling_impl(x)
         scale = threshold / int_threshold
         zero_point = self.zero_point_impl(x, scale, bit_width)
-        y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
+        if self.observer_only:
+            y = x
+        else:
+            y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
 
 
@@ -253,5 +261,8 @@ def forward(self, x: Tensor, input_bit_width: Tensor,
         threshold = self.scaling_impl(x)
         scale = threshold / int_threshold
         zero_point = self.zero_point_impl(x, scale, bit_width)
-        y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
+        if self.observer_only:
+            y = x
+        else:
+            y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py
index 74da08e19..67b57df6a 100644
--- a/src/brevitas/function/ops.py
+++ b/src/brevitas/function/ops.py
@@ -189,16 +189,16 @@ def min_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor:
     return value
 
 
-@brevitas.jit.ignore
+def max_mantissa_func(val):
+    return torch.sum((2. ** torch.arange(0, -1. * val - 1., -1.)))
+
+
+MAX_MANTISSA_DICT = {x: max_mantissa_func(x) for x in range(0, 16)}
+
+
 def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor):
     max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias
-    max_mantissa = torch.sum((
-        2. ** torch.arange(
-            0,
-            -1. * mantissa_bit_width - 1.,
-            -1.,
-            dtype=mantissa_bit_width.dtype,
-            device=mantissa_bit_width.device)))
+    max_mantissa = MAX_MANTISSA_DICT[mantissa_bit_width.item()]
     max_val = max_mantissa * (2 ** max_exponent)
     return max_val
 
diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py
index 2b1f6833e..6335d6d45 100644
--- a/src/brevitas/graph/calibrate.py
+++ b/src/brevitas/graph/calibrate.py
@@ -201,8 +201,9 @@ def disable_act_quantization(self, model, is_training):
             if isinstance(module, ActQuantProxyFromInjectorBase):
                 module.train(is_training)
                 if self.call_act_quantizer_impl:
-                    hook = module.register_forward_hook(self.disable_act_quant_hook)
-                    self.disable_act_quant_hooks.append(hook)
+                    for m in module.modules():
+                        if hasattr(m, 'observer_only'):
+                            m.observer_only = True
                 else:
                     module.disable_quant = True
             elif isinstance(module, _ACC_PROXIES):
@@ -229,9 +230,9 @@ def enable_act_quantization(self, model, is_training):
             elif isinstance(module, ActQuantProxyFromInjectorBase):
                 module.disable_quant = False
                 module.train(is_training)
-                for hook in self.disable_act_quant_hooks:
-                    hook.remove()
-                self.disable_act_quant_hooks = []
+                for m in module.modules():
+                    if hasattr(m, 'observer_only'):
+                        m.observer_only = False
 
     def enable_param_quantization(self, model, is_training):
         for module in model.modules():
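For reviewers unfamiliar with the pattern this patch introduces: `observer_only` turns each quantizer into a pure observer during calibration. The statistics path (scale, zero-point, bit-width) still executes on every forward, so observers keep updating, but the tensor itself passes through unquantized. Below is a minimal, self-contained sketch of that idea; `ToyIntQuant`, `running_max`, and the max-abs observer are hypothetical stand-ins for illustration, not Brevitas's actual classes or API.

```python
import torch
from torch import nn


class ToyIntQuant(nn.Module):
    """Hypothetical illustration of the observer-only pattern in this patch."""

    def __init__(self, bit_width: int = 8):
        super().__init__()
        self.bit_width = bit_width
        self.observer_only = False  # toggled externally, e.g. by a calibration context
        self.register_buffer('running_max', torch.tensor(1e-5))

    def forward(self, x: torch.Tensor):
        # The observer path always runs: statistics (and hence scale)
        # update from the float input regardless of the flag.
        self.running_max = torch.max(self.running_max, x.abs().max())
        int_max = 2 ** (self.bit_width - 1) - 1
        scale = self.running_max / int_max
        if self.observer_only:
            # Calibration mode: collect statistics, pass the input through.
            y = x
        else:
            # Normal mode: fake-quantize onto the integer grid.
            y = torch.clamp(torch.round(x / scale), -int_max - 1, int_max) * scale
        return y, scale


quant = ToyIntQuant()
x = torch.randn(16)

quant.observer_only = True  # calibration: observe only
y, scale = quant(x)
assert torch.equal(y, x)  # input is untouched, but scale was still updated

quant.observer_only = False  # back to fake-quantization
y_q, scale = quant(x)
```

Compared with the forward hook this patch removes, flipping a plain attribute avoids hook registration and cleanup bookkeeping, and declaring the flag as `brevitas.jit.Attribute(False, bool)` presumably lets the TorchScript-compiled `forward` treat it as a mutable boolean rather than a baked-in constant.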