From f7aae4d255fcdcbf6ecf9fe9076c9747d72ea916 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 20 Dec 2024 17:25:53 +0000 Subject: [PATCH 1/4] Fix (proxy): fix groupwise scale/zp caching --- src/brevitas/export/inference/handler.py | 12 ++++++++---- .../proxy/groupwise_float_parameter_quant.py | 18 ++++++++++++++++++ .../proxy/groupwise_int_parameter_quant.py | 18 ++++++++++++++++++ src/brevitas/proxy/parameter_quant.py | 15 ++++++++++++--- src/brevitas/utils/quant_utils.py | 14 ++++---------- 5 files changed, 60 insertions(+), 17 deletions(-) diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index c7fc21790..7e3347a21 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -55,8 +55,10 @@ def __init__(self): def prepare_for_export(self, module: nn.Module): if module.is_quant_enabled: - self.scale = module.scale() - self.zero_point = module.zero_point().to(self.scale.device) + self.scale = module.scale_() if hasattr(module, 'scale_') else module.scale() + self.zero_point = module.zero_point_() if hasattr( + module, 'zero_point_') else module.zero_point() + self.zero_point = self.zero_point.to(self.scale.device) self.bit_width = module.bit_width() self.min_clamp = min_int(module.is_signed, module.is_narrow_range, self.bit_width) self.max_clamp = max_int(module.is_signed, module.is_narrow_range, self.bit_width) @@ -177,8 +179,10 @@ def __init__(self): def prepare_for_export(self, module): if module.is_quant_enabled: - self.scale = module.scale() - self.zero_point = module.zero_point().to(self.scale.device) + self.scale = module.scale_() if hasattr(module, 'scale_') else module.scale() + self.zero_point = module.zero_point_() if hasattr( + module, 'zero_point_') else module.zero_point() + self.zero_point = self.zero_point.to(self.scale.device) self.exponent_bit_width = module.exponent_bit_width() self.mantissa_bit_width = module.mantissa_bit_width() self.exponent_bias = module.exponent_bias() diff --git a/src/brevitas/proxy/groupwise_float_parameter_quant.py b/src/brevitas/proxy/groupwise_float_parameter_quant.py index 9957848c1..510f8eae9 100644 --- a/src/brevitas/proxy/groupwise_float_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_float_parameter_quant.py @@ -14,6 +14,24 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: super().__init__(quant_layer, quant_injector) self.cache_class = _CachedIOGroupwiseFloat + def scale_(self): + if not self.is_quant_enabled: + return None + elif self._cached_weight: + scale = self._cached_weight.scale_ + else: + scale = self.__call__(self.tracked_parameter_list[0]).scale_ + return scale + + def zero_point_(self): + if not self.is_quant_enabled: + return None + elif self._cached_weight: + zero_point = self._cached_weight.zero_point_ + else: + zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point_ + return zero_point + @property def group_dim(self): return self.quant_injector.group_dim diff --git a/src/brevitas/proxy/groupwise_int_parameter_quant.py b/src/brevitas/proxy/groupwise_int_parameter_quant.py index 3c79a723b..43a2d9910 100644 --- a/src/brevitas/proxy/groupwise_int_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_int_parameter_quant.py @@ -14,6 +14,24 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: super().__init__(quant_layer, quant_injector) self.cache_class = _CachedIOGroupwiseInt + def scale_(self): + if not self.is_quant_enabled: + return None + elif self._cached_weight: + scale = self._cached_weight.scale_ + else: + scale = self.__call__(self.tracked_parameter_list[0]).scale_ + return scale + + def zero_point_(self): + if not self.is_quant_enabled: + return None + elif self._cached_weight: + zero_point = self._cached_weight.zero_point_ + else: + zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point_ + return zero_point + @property def group_dim(self): return self.quant_injector.group_dim diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 5c4e447d4..5f0c95518 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -195,19 +195,28 @@ def requires_quant_input(self): def scale(self): if not self.is_quant_enabled: return None - scale = self.__call__(self.tracked_parameter_list[0]).scale + elif self._cached_weight: + scale = self._cached_weight.scale + else: + scale = self.__call__(self.tracked_parameter_list[0]).scale return scale def zero_point(self): if not self.is_quant_enabled: return None - zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point + elif self._cached_weight: + zero_point = self._cached_weight.zero_point + else: + zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point return zero_point def bit_width(self): if not self.is_quant_enabled: return None - bit_width = self.__call__(self.tracked_parameter_list[0]).bit_width + elif self._cached_weight: + bit_width = self._cached_weight.bit_width + else: + bit_width = self.__call__(self.tracked_parameter_list[0]).bit_width return bit_width def create_quant_tensor(self, qt_args: Tuple[Any]) -> IntQuantTensor: diff --git a/src/brevitas/utils/quant_utils.py b/src/brevitas/utils/quant_utils.py index 62290b1de..a7c86d7bc 100644 --- a/src/brevitas/utils/quant_utils.py +++ b/src/brevitas/utils/quant_utils.py @@ -97,11 +97,8 @@ def __init__(self, quant_tensor: GroupwiseFloatQuantTensor, metadata_only: bool) # torch.compile compatibility self.value = quant_tensor.value # torch.compile compatibility - self.scale = quant_tensor.scale - - @property - def zero_point(self): - return self.quant_tensor.zero_point + self.scale_ = quant_tensor.scale_ + self.zero_point_ = quant_tensor.zero_point_ @property def exponent_bit_width(self): @@ -152,11 +149,8 @@ def __init__(self, quant_tensor: GroupwiseIntQuantTensor, metadata_only: bool): # torch.compile compatibility self.value = quant_tensor.value # torch.compile compatibility - self.scale = quant_tensor.scale - - @property - def zero_point(self): - return self.quant_tensor.zero_point + self.scale_ = quant_tensor.scale_ + self.zero_point_ = quant_tensor.zero_point_ @property def bit_width(self): From 0bcd6ca7f56ea84e5fa95abfff8276a9a7c31ba6 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 20 Dec 2024 18:23:14 +0000 Subject: [PATCH 2/4] minifloat fix --- src/brevitas/proxy/float_parameter_quant.py | 40 ++++++++++++++++----- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/src/brevitas/proxy/float_parameter_quant.py b/src/brevitas/proxy/float_parameter_quant.py index 88e78beaa..53765c13c 100644 --- a/src/brevitas/proxy/float_parameter_quant.py +++ b/src/brevitas/proxy/float_parameter_quant.py @@ -24,49 +24,73 @@ def bit_width(self): def scale(self): if not self.is_quant_enabled: return None - scale = self.__call__(self.tracked_parameter_list[0]).scale + elif self._cached_weight: + scale = self._cached_weight.scale + else: + scale = self.__call__(self.tracked_parameter_list[0]).scale return scale def zero_point(self): if not self.is_quant_enabled: return None - zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point + elif self._cached_weight: + zero_point = self._cached_weight.zero_point + else: + zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point return zero_point def exponent_bit_width(self): if not self.is_quant_enabled: return None - exponent_bit_width = self.__call__(self.tracked_parameter_list[0]).exponent_bit_width + elif self._cached_weight: + exponent_bit_width = self._cached_weight.exponent_bit_width + else: + exponent_bit_width = self.__call__(self.tracked_parameter_list[0]).exponent_bit_width return exponent_bit_width def mantissa_bit_width(self): if not self.is_quant_enabled: return None - mantissa_bit_width = self.__call__(self.tracked_parameter_list[0]).mantissa_bit_width + elif self._cached_weight: + mantissa_bit_width = self._cached_weight.mantissa_bit_width + else: + mantissa_bit_width = self.__call__(self.tracked_parameter_list[0]).mantissa_bit_width return mantissa_bit_width def exponent_bias(self): if not self.is_quant_enabled: return None - exponent_bias = self.__call__(self.tracked_parameter_list[0]).exponent_bias + elif self._cached_weight: + exponent_bias = self._cached_weight.exponent_bias + else: + exponent_bias = self.__call__(self.tracked_parameter_list[0]).exponent_bias return exponent_bias def is_saturating(self): if not self.is_quant_enabled: return None - saturating = self.__call__(self.tracked_parameter_list[0]).saturating + elif self._cached_weight: + saturating = self._cached_weight.saturating + else: + saturating = self.__call__(self.tracked_parameter_list[0]).saturating return saturating def inf_values(self): if not self.is_quant_enabled: return None - inf_values = self.__call__(self.tracked_parameter_list[0]).inf_values + elif self._cached_weight: + inf_values = self._cached_weight.inf_values + else: + inf_values = self.__call__(self.tracked_parameter_list[0]).inf_values return inf_values def nan_values(self): if not self.is_quant_enabled: return None - nan_values = self.__call__(self.tracked_parameter_list[0]).nan_values + elif self._cached_weight: + nan_values = self._cached_weight.nan_values + else: + nan_values = self.__call__(self.tracked_parameter_list[0]).nan_values return nan_values @property From c5d5c19fcdda2fddc7b481d0d7056251793beff5 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 20 Dec 2024 18:35:19 +0000 Subject: [PATCH 3/4] Cleanup --- src/brevitas/proxy/float_parameter_quant.py | 64 +++---------------- .../proxy/groupwise_float_parameter_quant.py | 16 +---- .../proxy/groupwise_int_parameter_quant.py | 16 +---- src/brevitas/proxy/parameter_quant.py | 33 ++++------ 4 files changed, 24 insertions(+), 105 deletions(-) diff --git a/src/brevitas/proxy/float_parameter_quant.py b/src/brevitas/proxy/float_parameter_quant.py index 53765c13c..de455db31 100644 --- a/src/brevitas/proxy/float_parameter_quant.py +++ b/src/brevitas/proxy/float_parameter_quant.py @@ -22,76 +22,28 @@ def bit_width(self): return bit_width def scale(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - scale = self._cached_weight.scale - else: - scale = self.__call__(self.tracked_parameter_list[0]).scale - return scale + self.retrieve_attribute('scale') def zero_point(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - zero_point = self._cached_weight.zero_point - else: - zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point - return zero_point + self.retrieve_attribute('zero_point') def exponent_bit_width(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - exponent_bit_width = self._cached_weight.exponent_bit_width - else: - exponent_bit_width = self.__call__(self.tracked_parameter_list[0]).exponent_bit_width - return exponent_bit_width + self.retrieve_attribute('exponent_bit_width') def mantissa_bit_width(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - mantissa_bit_width = self._cached_weight.mantissa_bit_width - else: - mantissa_bit_width = self.__call__(self.tracked_parameter_list[0]).mantissa_bit_width - return mantissa_bit_width + self.retrieve_attribute('mantissa_bit_width') def exponent_bias(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - exponent_bias = self._cached_weight.exponent_bias - else: - exponent_bias = self.__call__(self.tracked_parameter_list[0]).exponent_bias - return exponent_bias + self.retrieve_attribute('exponent_bias') def is_saturating(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - saturating = self._cached_weight.saturating - else: - saturating = self.__call__(self.tracked_parameter_list[0]).saturating - return saturating + self.retrieve_attribute('is_saturating') def inf_values(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - inf_values = self._cached_weight.inf_values - else: - inf_values = self.__call__(self.tracked_parameter_list[0]).inf_values - return inf_values + self.retrieve_attribute('inf_values') def nan_values(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - nan_values = self._cached_weight.nan_values - else: - nan_values = self.__call__(self.tracked_parameter_list[0]).nan_values - return nan_values + self.retrieve_attribute('nan_values') @property def is_ocp(self): diff --git a/src/brevitas/proxy/groupwise_float_parameter_quant.py b/src/brevitas/proxy/groupwise_float_parameter_quant.py index 510f8eae9..7c55c1958 100644 --- a/src/brevitas/proxy/groupwise_float_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_float_parameter_quant.py @@ -15,22 +15,10 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: self.cache_class = _CachedIOGroupwiseFloat def scale_(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - scale = self._cached_weight.scale_ - else: - scale = self.__call__(self.tracked_parameter_list[0]).scale_ - return scale + self.retrieve_attribute('scale_') def zero_point_(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - zero_point = self._cached_weight.zero_point_ - else: - zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point_ - return zero_point + self.retrieve_attribute('zero_point_') @property def group_dim(self): diff --git a/src/brevitas/proxy/groupwise_int_parameter_quant.py b/src/brevitas/proxy/groupwise_int_parameter_quant.py index 43a2d9910..a9fd169f7 100644 --- a/src/brevitas/proxy/groupwise_int_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_int_parameter_quant.py @@ -15,22 +15,10 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: self.cache_class = _CachedIOGroupwiseInt def scale_(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - scale = self._cached_weight.scale_ - else: - scale = self.__call__(self.tracked_parameter_list[0]).scale_ - return scale + self.retrieve_attribute('scale_') def zero_point_(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - zero_point = self._cached_weight.zero_point_ - else: - zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point_ - return zero_point + self.retrieve_attribute('zero_point_') @property def group_dim(self): diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 5f0c95518..85b74296d 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -118,6 +118,15 @@ def cache_inference_quant_weight(self, value): def tracked_parameter_list(self): return [m.weight for m in self.tracked_module_list if m.weight is not None] + def retrieve_attribute(self, attribute: str): + if not self.is_quant_enabled: + return None + elif self._cached_weight is not None: + return getattr(self._cached_weight, attribute) + else: + out = self.__call__(self.tracked_parameter_list[0]) + return getattr(out, attribute) + @property def requires_quant_input(self): return False @@ -193,31 +202,13 @@ def requires_quant_input(self): return False def scale(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - scale = self._cached_weight.scale - else: - scale = self.__call__(self.tracked_parameter_list[0]).scale - return scale + self.retrieve_attribute('scale') def zero_point(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - zero_point = self._cached_weight.zero_point - else: - zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point - return zero_point + self.retrieve_attribute('zero_point') def bit_width(self): - if not self.is_quant_enabled: - return None - elif self._cached_weight: - bit_width = self._cached_weight.bit_width - else: - bit_width = self.__call__(self.tracked_parameter_list[0]).bit_width - return bit_width + self.retrieve_attribute('bit_width') def create_quant_tensor(self, qt_args: Tuple[Any]) -> IntQuantTensor: return IntQuantTensor(*qt_args, self.is_signed, self.training) From 848cc7d443fe1244700feb54a10a2f7d382323f6 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 20 Dec 2024 19:12:55 +0000 Subject: [PATCH 4/4] missing return --- src/brevitas/proxy/float_parameter_quant.py | 16 ++++++++-------- .../proxy/groupwise_float_parameter_quant.py | 4 ++-- .../proxy/groupwise_int_parameter_quant.py | 4 ++-- src/brevitas/proxy/parameter_quant.py | 6 +++--- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/brevitas/proxy/float_parameter_quant.py b/src/brevitas/proxy/float_parameter_quant.py index de455db31..027625f11 100644 --- a/src/brevitas/proxy/float_parameter_quant.py +++ b/src/brevitas/proxy/float_parameter_quant.py @@ -22,28 +22,28 @@ def bit_width(self): return bit_width def scale(self): - self.retrieve_attribute('scale') + return self.retrieve_attribute('scale') def zero_point(self): - self.retrieve_attribute('zero_point') + return self.retrieve_attribute('zero_point') def exponent_bit_width(self): - self.retrieve_attribute('exponent_bit_width') + return self.retrieve_attribute('exponent_bit_width') def mantissa_bit_width(self): - self.retrieve_attribute('mantissa_bit_width') + return self.retrieve_attribute('mantissa_bit_width') def exponent_bias(self): - self.retrieve_attribute('exponent_bias') + return self.retrieve_attribute('exponent_bias') def is_saturating(self): - self.retrieve_attribute('is_saturating') + return self.retrieve_attribute('saturating') def inf_values(self): - self.retrieve_attribute('inf_values') + return self.retrieve_attribute('inf_values') def nan_values(self): - self.retrieve_attribute('nan_values') + return self.retrieve_attribute('nan_values') @property def is_ocp(self): diff --git a/src/brevitas/proxy/groupwise_float_parameter_quant.py b/src/brevitas/proxy/groupwise_float_parameter_quant.py index 7c55c1958..12aacd23b 100644 --- a/src/brevitas/proxy/groupwise_float_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_float_parameter_quant.py @@ -15,10 +15,10 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: self.cache_class = _CachedIOGroupwiseFloat def scale_(self): - self.retrieve_attribute('scale_') + return self.retrieve_attribute('scale_') def zero_point_(self): - self.retrieve_attribute('zero_point_') + return self.retrieve_attribute('zero_point_') @property def group_dim(self): diff --git a/src/brevitas/proxy/groupwise_int_parameter_quant.py b/src/brevitas/proxy/groupwise_int_parameter_quant.py index a9fd169f7..905e50c52 100644 --- a/src/brevitas/proxy/groupwise_int_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_int_parameter_quant.py @@ -15,10 +15,10 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: self.cache_class = _CachedIOGroupwiseInt def scale_(self): - self.retrieve_attribute('scale_') + return self.retrieve_attribute('scale_') def zero_point_(self): - self.retrieve_attribute('zero_point_') + return self.retrieve_attribute('zero_point_') @property def group_dim(self): diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 85b74296d..2ca0afe92 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -202,13 +202,13 @@ def requires_quant_input(self): return False def scale(self): - self.retrieve_attribute('scale') + return self.retrieve_attribute('scale') def zero_point(self): - self.retrieve_attribute('zero_point') + return self.retrieve_attribute('zero_point') def bit_width(self): - self.retrieve_attribute('bit_width') + return self.retrieve_attribute('bit_width') def create_quant_tensor(self, qt_args: Tuple[Any]) -> IntQuantTensor: return IntQuantTensor(*qt_args, self.is_signed, self.training)