Cleanup
Giuseppe5 committed Dec 20, 2024
1 parent 0bcd6ca commit c5d5c19
Showing 4 changed files with 24 additions and 105 deletions.
64 changes: 8 additions & 56 deletions src/brevitas/proxy/float_parameter_quant.py
@@ -22,76 +22,28 @@ def bit_width(self):
         return bit_width
 
     def scale(self):
-        if not self.is_quant_enabled:
-            return None
-        elif self._cached_weight:
-            scale = self._cached_weight.scale
-        else:
-            scale = self.__call__(self.tracked_parameter_list[0]).scale
-        return scale
+        return self.retrieve_attribute('scale')
 
     def zero_point(self):
-        if not self.is_quant_enabled:
-            return None
-        elif self._cached_weight:
-            zero_point = self._cached_weight.zero_point
-        else:
-            zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point
-        return zero_point
+        return self.retrieve_attribute('zero_point')
 
     def exponent_bit_width(self):
-        if not self.is_quant_enabled:
-            return None
-        elif self._cached_weight:
-            exponent_bit_width = self._cached_weight.exponent_bit_width
-        else:
-            exponent_bit_width = self.__call__(self.tracked_parameter_list[0]).exponent_bit_width
-        return exponent_bit_width
+        return self.retrieve_attribute('exponent_bit_width')
 
     def mantissa_bit_width(self):
-        if not self.is_quant_enabled:
-            return None
-        elif self._cached_weight:
-            mantissa_bit_width = self._cached_weight.mantissa_bit_width
-        else:
-            mantissa_bit_width = self.__call__(self.tracked_parameter_list[0]).mantissa_bit_width
-        return mantissa_bit_width
+        return self.retrieve_attribute('mantissa_bit_width')
 
     def exponent_bias(self):
-        if not self.is_quant_enabled:
-            return None
-        elif self._cached_weight:
-            exponent_bias = self._cached_weight.exponent_bias
-        else:
-            exponent_bias = self.__call__(self.tracked_parameter_list[0]).exponent_bias
-        return exponent_bias
+        return self.retrieve_attribute('exponent_bias')
 
     def is_saturating(self):
-        if not self.is_quant_enabled:
-            return None
-        elif self._cached_weight:
-            saturating = self._cached_weight.saturating
-        else:
-            saturating = self.__call__(self.tracked_parameter_list[0]).saturating
-        return saturating
+        return self.retrieve_attribute('saturating')
 
     def inf_values(self):
-        if not self.is_quant_enabled:
-            return None
-        elif self._cached_weight:
-            inf_values = self._cached_weight.inf_values
-        else:
-            inf_values = self.__call__(self.tracked_parameter_list[0]).inf_values
-        return inf_values
+        return self.retrieve_attribute('inf_values')
 
     def nan_values(self):
-        if not self.is_quant_enabled:
-            return None
-        elif self._cached_weight:
-            nan_values = self._cached_weight.nan_values
-        else:
-            nan_values = self.__call__(self.tracked_parameter_list[0]).nan_values
-        return nan_values
+        return self.retrieve_attribute('nan_values')
 
     @property
     def is_ocp(self):
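For orientation, a hypothetical usage sketch of the accessors refactored above: a quantized layer exposes them through its weight_quant proxy. The quantizer name and import path below are assumptions chosen for illustration, not something this commit defines.

# Hypothetical usage sketch; the quantizer name and import path are assumptions.
from brevitas.nn import QuantLinear
from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat  # assumed import path

layer = QuantLinear(16, 32, weight_quant=Fp8e4m3WeightPerTensorFloat)
proxy = layer.weight_quant
print(proxy.scale())                # quantization scale, or None if quantization is disabled
print(proxy.exponent_bit_width())   # exponent width of the e4m3 format (4)
print(proxy.mantissa_bit_width())   # mantissa width of the e4m3 format (3)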
16 changes: 2 additions & 14 deletions src/brevitas/proxy/groupwise_float_parameter_quant.py
@@ -15,22 +15,10 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
         self.cache_class = _CachedIOGroupwiseFloat
 
     def scale_(self):
-        if not self.is_quant_enabled:
-            return None
-        elif self._cached_weight:
-            scale = self._cached_weight.scale_
-        else:
-            scale = self.__call__(self.tracked_parameter_list[0]).scale_
-        return scale
+        return self.retrieve_attribute('scale_')
 
     def zero_point_(self):
-        if not self.is_quant_enabled:
-            return None
-        elif self._cached_weight:
-            zero_point = self._cached_weight.zero_point_
-        else:
-            zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point_
-        return zero_point
+        return self.retrieve_attribute('zero_point_')
 
     @property
     def group_dim(self):
16 changes: 2 additions & 14 deletions src/brevitas/proxy/groupwise_int_parameter_quant.py
@@ -15,22 +15,10 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
         self.cache_class = _CachedIOGroupwiseInt
 
     def scale_(self):
-        if not self.is_quant_enabled:
-            return None
-        elif self._cached_weight:
-            scale = self._cached_weight.scale_
-        else:
-            scale = self.__call__(self.tracked_parameter_list[0]).scale_
-        return scale
+        return self.retrieve_attribute('scale_')
 
     def zero_point_(self):
-        if not self.is_quant_enabled:
-            return None
-        elif self._cached_weight:
-            zero_point = self._cached_weight.zero_point_
-        else:
-            zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point_
-        return zero_point
+        return self.retrieve_attribute('zero_point_')
 
     @property
     def group_dim(self):
33 changes: 12 additions & 21 deletions src/brevitas/proxy/parameter_quant.py
@@ -118,6 +118,15 @@ def cache_inference_quant_weight(self, value):
     def tracked_parameter_list(self):
         return [m.weight for m in self.tracked_module_list if m.weight is not None]
 
+    def retrieve_attribute(self, attribute: str):
+        if not self.is_quant_enabled:
+            return None
+        elif self._cached_weight is not None:
+            return getattr(self._cached_weight, attribute)
+        else:
+            out = self.__call__(self.tracked_parameter_list[0])
+            return getattr(out, attribute)
+
     @property
     def requires_quant_input(self):
         return False
@@ -193,31 +202,13 @@ def requires_quant_input(self):
         return False
 
     def scale(self):
-        if not self.is_quant_enabled:
-            return None
-        elif self._cached_weight:
-            scale = self._cached_weight.scale
-        else:
-            scale = self.__call__(self.tracked_parameter_list[0]).scale
-        return scale
+        return self.retrieve_attribute('scale')
 
     def zero_point(self):
-        if not self.is_quant_enabled:
-            return None
-        elif self._cached_weight:
-            zero_point = self._cached_weight.zero_point
-        else:
-            zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point
-        return zero_point
+        return self.retrieve_attribute('zero_point')
 
     def bit_width(self):
-        if not self.is_quant_enabled:
-            return None
-        elif self._cached_weight:
-            bit_width = self._cached_weight.bit_width
-        else:
-            bit_width = self.__call__(self.tracked_parameter_list[0]).bit_width
-        return bit_width
+        return self.retrieve_attribute('bit_width')
 
     def create_quant_tensor(self, qt_args: Tuple[Any]) -> IntQuantTensor:
         return IntQuantTensor(*qt_args, self.is_signed, self.training)
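For clarity, a minimal, self-contained sketch of the dispatch that the new retrieve_attribute helper centralizes (toy class and names, not the actual Brevitas proxy hierarchy): return None when quantization is disabled, read the attribute from the cached quantized weight when one exists, otherwise quantize the first tracked parameter and read the attribute from the result.

# Toy sketch only; class and constructor names are illustrative, not Brevitas's real ones.
class ToyWeightQuantProxy:

    def __init__(self, quantize_fn, tracked_parameter_list, cached_weight=None, enabled=True):
        self.quantize_fn = quantize_fn                    # callable returning an object carrying quant metadata
        self.tracked_parameter_list = tracked_parameter_list
        self._cached_weight = cached_weight               # cached quantized weight, or None
        self.is_quant_enabled = enabled

    def retrieve_attribute(self, attribute: str):
        if not self.is_quant_enabled:
            return None                                   # quantization disabled: nothing to report
        elif self._cached_weight is not None:
            return getattr(self._cached_weight, attribute)  # fast path: read from the cache
        else:
            out = self.quantize_fn(self.tracked_parameter_list[0])
            return getattr(out, attribute)                # fallback: re-quantize and read

    # every metadata accessor collapses to a one-liner, as in the diff above
    def scale(self):
        return self.retrieve_attribute('scale')

    def zero_point(self):
        return self.retrieve_attribute('zero_point')

The quant-enabled check, cache lookup, and recompute fallback now live in one place, so adding a new metadata field only requires a one-line accessor.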
