diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index c7fc21790..7e3347a21 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -55,8 +55,10 @@ def __init__(self): def prepare_for_export(self, module: nn.Module): if module.is_quant_enabled: - self.scale = module.scale() - self.zero_point = module.zero_point().to(self.scale.device) + self.scale = module.scale_() if hasattr(module, 'scale_') else module.scale() + self.zero_point = module.zero_point_() if hasattr( + module, 'zero_point_') else module.zero_point() + self.zero_point = self.zero_point.to(self.scale.device) self.bit_width = module.bit_width() self.min_clamp = min_int(module.is_signed, module.is_narrow_range, self.bit_width) self.max_clamp = max_int(module.is_signed, module.is_narrow_range, self.bit_width) @@ -177,8 +179,10 @@ def __init__(self): def prepare_for_export(self, module): if module.is_quant_enabled: - self.scale = module.scale() - self.zero_point = module.zero_point().to(self.scale.device) + self.scale = module.scale_() if hasattr(module, 'scale_') else module.scale() + self.zero_point = module.zero_point_() if hasattr( + module, 'zero_point_') else module.zero_point() + self.zero_point = self.zero_point.to(self.scale.device) self.exponent_bit_width = module.exponent_bit_width() self.mantissa_bit_width = module.mantissa_bit_width() self.exponent_bias = module.exponent_bias() diff --git a/src/brevitas/proxy/float_parameter_quant.py b/src/brevitas/proxy/float_parameter_quant.py index 88e78beaa..027625f11 100644 --- a/src/brevitas/proxy/float_parameter_quant.py +++ b/src/brevitas/proxy/float_parameter_quant.py @@ -22,52 +22,28 @@ def bit_width(self): return bit_width def scale(self): - if not self.is_quant_enabled: - return None - scale = self.__call__(self.tracked_parameter_list[0]).scale - return scale + return self.retrieve_attribute('scale') def zero_point(self): - if not self.is_quant_enabled: - return None - zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point - return zero_point + return self.retrieve_attribute('zero_point') def exponent_bit_width(self): - if not self.is_quant_enabled: - return None - exponent_bit_width = self.__call__(self.tracked_parameter_list[0]).exponent_bit_width - return exponent_bit_width + return self.retrieve_attribute('exponent_bit_width') def mantissa_bit_width(self): - if not self.is_quant_enabled: - return None - mantissa_bit_width = self.__call__(self.tracked_parameter_list[0]).mantissa_bit_width - return mantissa_bit_width + return self.retrieve_attribute('mantissa_bit_width') def exponent_bias(self): - if not self.is_quant_enabled: - return None - exponent_bias = self.__call__(self.tracked_parameter_list[0]).exponent_bias - return exponent_bias + return self.retrieve_attribute('exponent_bias') def is_saturating(self): - if not self.is_quant_enabled: - return None - saturating = self.__call__(self.tracked_parameter_list[0]).saturating - return saturating + return self.retrieve_attribute('saturating') def inf_values(self): - if not self.is_quant_enabled: - return None - inf_values = self.__call__(self.tracked_parameter_list[0]).inf_values - return inf_values + return self.retrieve_attribute('inf_values') def nan_values(self): - if not self.is_quant_enabled: - return None - nan_values = self.__call__(self.tracked_parameter_list[0]).nan_values - return nan_values + return self.retrieve_attribute('nan_values') @property def is_ocp(self): diff --git a/src/brevitas/proxy/groupwise_float_parameter_quant.py b/src/brevitas/proxy/groupwise_float_parameter_quant.py index 9957848c1..12aacd23b 100644 --- a/src/brevitas/proxy/groupwise_float_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_float_parameter_quant.py @@ -14,6 +14,12 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: super().__init__(quant_layer, quant_injector) self.cache_class = _CachedIOGroupwiseFloat + def scale_(self): + return self.retrieve_attribute('scale_') + + def zero_point_(self): + return self.retrieve_attribute('zero_point_') + @property def group_dim(self): return self.quant_injector.group_dim diff --git a/src/brevitas/proxy/groupwise_int_parameter_quant.py b/src/brevitas/proxy/groupwise_int_parameter_quant.py index 3c79a723b..905e50c52 100644 --- a/src/brevitas/proxy/groupwise_int_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_int_parameter_quant.py @@ -14,6 +14,12 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: super().__init__(quant_layer, quant_injector) self.cache_class = _CachedIOGroupwiseInt + def scale_(self): + return self.retrieve_attribute('scale_') + + def zero_point_(self): + return self.retrieve_attribute('zero_point_') + @property def group_dim(self): return self.quant_injector.group_dim diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 5c4e447d4..2ca0afe92 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -118,6 +118,15 @@ def cache_inference_quant_weight(self, value): def tracked_parameter_list(self): return [m.weight for m in self.tracked_module_list if m.weight is not None] + def retrieve_attribute(self, attribute: str): + if not self.is_quant_enabled: + return None + elif self._cached_weight is not None: + return getattr(self._cached_weight, attribute) + else: + out = self.__call__(self.tracked_parameter_list[0]) + return getattr(out, attribute) + @property def requires_quant_input(self): return False @@ -193,22 +202,13 @@ def requires_quant_input(self): return False def scale(self): - if not self.is_quant_enabled: - return None - scale = self.__call__(self.tracked_parameter_list[0]).scale - return scale + return self.retrieve_attribute('scale') def zero_point(self): - if not self.is_quant_enabled: - return None - zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point - return zero_point + return self.retrieve_attribute('zero_point') def bit_width(self): - if not self.is_quant_enabled: - return None - bit_width = self.__call__(self.tracked_parameter_list[0]).bit_width - return bit_width + return self.retrieve_attribute('bit_width') def create_quant_tensor(self, qt_args: Tuple[Any]) -> IntQuantTensor: return IntQuantTensor(*qt_args, self.is_signed, self.training) diff --git a/src/brevitas/utils/quant_utils.py b/src/brevitas/utils/quant_utils.py index 62290b1de..a7c86d7bc 100644 --- a/src/brevitas/utils/quant_utils.py +++ b/src/brevitas/utils/quant_utils.py @@ -97,11 +97,8 @@ def __init__(self, quant_tensor: GroupwiseFloatQuantTensor, metadata_only: bool) # torch.compile compatibility self.value = quant_tensor.value # torch.compile compatibility - self.scale = quant_tensor.scale - - @property - def zero_point(self): - return self.quant_tensor.zero_point + self.scale_ = quant_tensor.scale_ + self.zero_point_ = quant_tensor.zero_point_ @property def exponent_bit_width(self): @@ -152,11 +149,8 @@ def __init__(self, quant_tensor: GroupwiseIntQuantTensor, metadata_only: bool): # torch.compile compatibility self.value = quant_tensor.value # torch.compile compatibility - self.scale = quant_tensor.scale - - @property - def zero_point(self): - return self.quant_tensor.zero_point + self.scale_ = quant_tensor.scale_ + self.zero_point_ = quant_tensor.zero_point_ @property def bit_width(self):