From c5d5c19fcdda2fddc7b481d0d7056251793beff5 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 20 Dec 2024 18:35:19 +0000 Subject: [PATCH] Cleanup --- src/brevitas/proxy/float_parameter_quant.py | 64 +++---------------- .../proxy/groupwise_float_parameter_quant.py | 16 +---- .../proxy/groupwise_int_parameter_quant.py | 16 +---- src/brevitas/proxy/parameter_quant.py | 33 ++++------ 4 files changed, 24 insertions(+), 105 deletions(-) diff --git a/src/brevitas/proxy/float_parameter_quant.py b/src/brevitas/proxy/float_parameter_quant.py index 53765c13c..de455db31 100644 --- a/src/brevitas/proxy/float_parameter_quant.py +++ b/src/brevitas/proxy/float_parameter_quant.py @@ -22,76 +22,28 @@ def bit_width(self): return bit_width def scale(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - scale = self._cached_weight.scale - else: - scale = self.__call__(self.tracked_parameter_list[0]).scale - return scale + self.retrieve_attribute('scale') def zero_point(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - zero_point = self._cached_weight.zero_point - else: - zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point - return zero_point + self.retrieve_attribute('zero_point') def exponent_bit_width(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - exponent_bit_width = self._cached_weight.exponent_bit_width - else: - exponent_bit_width = self.__call__(self.tracked_parameter_list[0]).exponent_bit_width - return exponent_bit_width + self.retrieve_attribute('exponent_bit_width') def mantissa_bit_width(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - mantissa_bit_width = self._cached_weight.mantissa_bit_width - else: - mantissa_bit_width = self.__call__(self.tracked_parameter_list[0]).mantissa_bit_width - return mantissa_bit_width + self.retrieve_attribute('mantissa_bit_width') def exponent_bias(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - exponent_bias = self._cached_weight.exponent_bias - else: - exponent_bias = self.__call__(self.tracked_parameter_list[0]).exponent_bias - return exponent_bias + self.retrieve_attribute('exponent_bias') def is_saturating(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - saturating = self._cached_weight.saturating - else: - saturating = self.__call__(self.tracked_parameter_list[0]).saturating - return saturating + self.retrieve_attribute('is_saturating') def inf_values(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - inf_values = self._cached_weight.inf_values - else: - inf_values = self.__call__(self.tracked_parameter_list[0]).inf_values - return inf_values + self.retrieve_attribute('inf_values') def nan_values(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - nan_values = self._cached_weight.nan_values - else: - nan_values = self.__call__(self.tracked_parameter_list[0]).nan_values - return nan_values + self.retrieve_attribute('nan_values') @property def is_ocp(self): diff --git a/src/brevitas/proxy/groupwise_float_parameter_quant.py b/src/brevitas/proxy/groupwise_float_parameter_quant.py index 510f8eae9..7c55c1958 100644 --- a/src/brevitas/proxy/groupwise_float_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_float_parameter_quant.py @@ -15,22 +15,10 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: self.cache_class = _CachedIOGroupwiseFloat def scale_(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - scale = self._cached_weight.scale_ - else: - scale = self.__call__(self.tracked_parameter_list[0]).scale_ - return scale + self.retrieve_attribute('scale_') def zero_point_(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - zero_point = self._cached_weight.zero_point_ - else: - zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point_ - return zero_point + self.retrieve_attribute('zero_point_') @property def group_dim(self): diff --git a/src/brevitas/proxy/groupwise_int_parameter_quant.py b/src/brevitas/proxy/groupwise_int_parameter_quant.py index 43a2d9910..a9fd169f7 100644 --- a/src/brevitas/proxy/groupwise_int_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_int_parameter_quant.py @@ -15,22 +15,10 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: self.cache_class = _CachedIOGroupwiseInt def scale_(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - scale = self._cached_weight.scale_ - else: - scale = self.__call__(self.tracked_parameter_list[0]).scale_ - return scale + self.retrieve_attribute('scale_') def zero_point_(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - zero_point = self._cached_weight.zero_point_ - else: - zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point_ - return zero_point + self.retrieve_attribute('zero_point_') @property def group_dim(self): diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 5f0c95518..85b74296d 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -118,6 +118,15 @@ def cache_inference_quant_weight(self, value): def tracked_parameter_list(self): return [m.weight for m in self.tracked_module_list if m.weight is not None] + def retrieve_attribute(self, attribute: str): + if not self.is_quant_enabled: + return None + elif self._cached_weight is not None: + return getattr(self._cached_weight, attribute) + else: + out = self.__call__(self.tracked_parameter_list[0]) + return getattr(out, attribute) + @property def requires_quant_input(self): return False @@ -193,31 +202,13 @@ def requires_quant_input(self): return False def scale(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - scale = self._cached_weight.scale - else: - scale = self.__call__(self.tracked_parameter_list[0]).scale - return scale + self.retrieve_attribute('scale') def zero_point(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - zero_point = self._cached_weight.zero_point - else: - zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point - return zero_point + self.retrieve_attribute('zero_point') def bit_width(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - bit_width = self._cached_weight.bit_width - else: - bit_width = self.__call__(self.tracked_parameter_list[0]).bit_width - return bit_width + self.retrieve_attribute('bit_width') def create_quant_tensor(self, qt_args: Tuple[Any]) -> IntQuantTensor: return IntQuantTensor(*qt_args, self.is_signed, self.training)