From 6005c97113421f51a35125109fe4e3d3075e70ad Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Aug 2024 23:01:25 +0100 Subject: [PATCH 1/7] Fix (proxy): clean-up --- src/brevitas/nn/mixin/base.py | 5 - src/brevitas/proxy/float_parameter_quant.py | 56 +++------ src/brevitas/proxy/float_runtime_quant.py | 116 +++--------------- .../proxy/groupwise_float_parameter_quant.py | 39 +++--- .../proxy/groupwise_float_runtime_quant.py | 102 ++++++--------- .../proxy/groupwise_int_parameter_quant.py | 35 +++--- .../proxy/groupwise_int_runtime_quant.py | 79 ++++-------- src/brevitas/proxy/parameter_quant.py | 101 ++++++++------- src/brevitas/proxy/quant_proxy.py | 12 +- src/brevitas/proxy/runtime_quant.py | 92 +++++++------- src/brevitas/utils/quant_utils.py | 36 +++--- .../ptq/ptq_evaluate.py | 2 + 12 files changed, 261 insertions(+), 414 deletions(-) diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index 167852508..d64271cb5 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -69,9 +69,6 @@ def __init__(self, return_quant_tensor: bool): def channelwise_separable(self) -> bool: pass - def _set_global_is_quant_layer(self, value): - config._IS_INSIDE_QUANT_LAYER = value - def get_quant_tensor_class(self, inp: Union[Tensor, QuantTensor]): quant_tensor_classes = [ IntQuantTensor, FloatQuantTensor, GroupwiseIntQuantTensor, GroupwiseFloatQuantTensor] @@ -81,7 +78,6 @@ def get_quant_tensor_class(self, inp: Union[Tensor, QuantTensor]): return None def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: - self._set_global_is_quant_layer(True) # Hack to recognize a QuantTensor that has decayed to a tuple # when used as input to tracing (e.g. during ONNX export) if (torch._C._get_tracing_state() is not None and isinstance(inp, tuple) and @@ -97,7 +93,6 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe return inp def pack_output(self, quant_output: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: - self._set_global_is_quant_layer(False) if self.return_quant_tensor: assert isinstance(quant_output, QuantTensor), 'QuantLayer is not correctly configured, check if warnings were raised' return quant_output diff --git a/src/brevitas/proxy/float_parameter_quant.py b/src/brevitas/proxy/float_parameter_quant.py index 68038fa20..9bad29e79 100644 --- a/src/brevitas/proxy/float_parameter_quant.py +++ b/src/brevitas/proxy/float_parameter_quant.py @@ -1,5 +1,4 @@ -from typing import Optional, Union -from warnings import warn +from typing import Any, List, Optional, Union import torch from torch import Tensor @@ -8,7 +7,9 @@ from brevitas.inject import BaseInjector as Injector from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjectorBase from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase +from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import FloatQuantTensor +from brevitas.quant_tensor.base_quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIOFloat @@ -84,46 +85,23 @@ def is_fnuz(self): ) is None and self.exponent_bias() == 16 return is_fnuz_e4m3 or is_fnuz_e5m2 - def forward(self, x: torch.Tensor) -> Union[Tensor, FloatQuantTensor]: - if self.is_quant_enabled: - impl = self.export_handler if self.export_mode else self.tensor_quant - out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x) - return FloatQuantTensor( - out, - 
scale, - zero_point, - exponent_bit_width, - mantissa_bit_width, - exponent_bias, - saturating, - inf_values, - nan_values, - self.is_signed, - self.training) - else: # quantization disabled - return x - class WeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase): - def forward(self, x: torch.Tensor) -> Union[Tensor, FloatQuantTensor]: - if self.is_quant_enabled: - impl = self.export_handler if self.export_mode else self.tensor_quant - out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x) - return FloatQuantTensor( - out, - scale, - zero_point, - exponent_bit_width, - mantissa_bit_width, - exponent_bias, - saturating, - inf_values, - nan_values, - self.is_signed, - self.training) - else: # quantization disabled - return x + def create_quant_tensor(self, qt_args: List[Any]) -> Union[Tensor, QuantTensor]: + out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args + return FloatQuantTensor( + out, + scale, + zero_point, + exponent_bit_width, + mantissa_bit_width, + exponent_bias, + saturating, + inf_values, + nan_values, + self.is_signed, + self.training) class BiasFloatQuantProxyFromInjector(BiasQuantProxyFromInjectorBase): diff --git a/src/brevitas/proxy/float_runtime_quant.py b/src/brevitas/proxy/float_runtime_quant.py index 021aefd12..eddfb0e17 100644 --- a/src/brevitas/proxy/float_runtime_quant.py +++ b/src/brevitas/proxy/float_runtime_quant.py @@ -60,105 +60,27 @@ def is_fnuz(self): ) is None and self.exponent_bias() == 16 return is_fnuz_e4m3 or is_fnuz_e5m2 - def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, FloatQuantTensor]: - out = x - if self.fused_activation_quant_proxy is not None: - y = x - if isinstance(y, QuantTensor): - y = y.value - - if self.export_mode: - y = self.fused_activation_quant_proxy.activation_impl(y) - y = self.export_handler(y) - elif not self.is_quant_enabled: - y = self.fused_activation_quant_proxy.activation_impl(y) - else: - y = self.fused_activation_quant_proxy(y) - # If y is an empty FloatQuantTensor, we need to check if this is a passthrough proxy, - # otherwise return a simple Tensor - # We exclude the last two values (inf_values and nan_values) - if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])): - out = FloatQuantTensor(*y, signed=self.is_signed, training=self.training) - elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant - if isinstance(y, tuple): - y = y[0] - if isinstance(x, FloatQuantTensor): - out = FloatQuantTensor( - y, - x.scale, - x.zero_point, - x.exponent_bit_width, - x.mantissa_bit_width, - x.exponent_bias, - x.saturating, - x.inf_values, - x.nan_values, - x.signed, - self.training) - else: - out = y - else: - if isinstance(y, tuple): - y = y[0] - out = y - else: - # If fused activation quant proxy is not enabled, return the input - out = x - if not self.training and self.cache_inference_quant_act and isinstance(out, - FloatQuantTensor): - cached_out = _CachedIOFloat(out.detach(), self.cache_quant_io_metadata_only) - self._cached_act = cached_out - return out - class ActFloatQuantProxyFromInjector(ActFloatQuantProxyFromInjectorBase): - def forward(self, x: Union[Tensor, FloatQuantTensor]) -> Union[Tensor, FloatQuantTensor]: - out = x - if self.fused_activation_quant_proxy is not None: - y = x - if isinstance(y, FloatQuantTensor): - y = y.value - - if self.export_mode: - y = 
self.fused_activation_quant_proxy.activation_impl(y) - y = self.export_handler(y) - elif not self.is_quant_enabled: - y = self.fused_activation_quant_proxy.activation_impl(y) - else: - y = self.fused_activation_quant_proxy(y) - # If y is an empty FloatQuantTensor, we need to check if this is a passthrough proxy, - # otherwise return a simple Tensor - # We exclude the last two values (inf_values and nan_values) - if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])): - out = FloatQuantTensor(*y, signed=self.is_signed, training=self.training) - elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant - if isinstance(y, tuple): - y = y[0] - if isinstance(x, FloatQuantTensor): - out = FloatQuantTensor( - y, - x.scale, - x.zero_point, - x.mantissa_bit_width, - x.exponent_bit_width, - x.exponent_bias, - x.saturating, - x.inf_values, - x.nan_values, - x.signed, - self.training) - else: - out = y - else: - if isinstance(y, tuple): - y = y[0] - out = y + def __init__(self, quant_layer, quant_injector): + super().__init__(quant_layer, quant_injector) + self.cache_class = _CachedIOFloat + + def create_quant_tensor(self, qt_args, x=None): + if x is None: + out = FloatQuantTensor(*qt_args, signed=self.is_signed, training=self.training) else: - # If fused activation quant proxy is not enabled, return the input - out = x - if not self.training and self.cache_inference_quant_act and isinstance(out, - FloatQuantTensor): - cached_out = _CachedIOFloat(out.detach(), self.cache_quant_io_metadata_only) - self._cached_act = cached_out + out = FloatQuantTensor( + qt_args, + x.scale, + x.zero_point, + x.mantissa_bit_width, + x.exponent_bit_width, + x.exponent_bias, + x.saturating, + x.inf_values, + x.nan_values, + x.signed, + self.training) return out diff --git a/src/brevitas/proxy/groupwise_float_parameter_quant.py b/src/brevitas/proxy/groupwise_float_parameter_quant.py index d08033f8e..cbd91e463 100644 --- a/src/brevitas/proxy/groupwise_float_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_float_parameter_quant.py @@ -1,6 +1,5 @@ -from typing import Union +from typing import Any, List, Union -import torch from torch import Tensor from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjectorBase @@ -17,23 +16,19 @@ def group_dim(self): def group_size(self): return self.quant_injector.group_size - def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseFloatQuantTensor]: - if self.is_quant_enabled: - impl = self.export_handler if self.export_mode else self.tensor_quant - out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x) - return GroupwiseFloatQuantTensor( - out, - scale, - zero_point, - self.group_size, - self.group_dim, - exponent_bit_width, - mantissa_bit_width, - exponent_bias, - saturating, - inf_values, - nan_values, - self.is_signed, - self.training) - else: # quantization disabled - return x + def create_quant_tensor(self, qt_args: List[Any]) -> Union[Tensor, GroupwiseFloatQuantTensor]: + out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args + return GroupwiseFloatQuantTensor( + out, + scale, + zero_point, + self.group_size, + self.group_dim, + exponent_bit_width, + mantissa_bit_width, + exponent_bias, + saturating, + inf_values, + nan_values, + self.is_signed, + self.training) diff --git a/src/brevitas/proxy/groupwise_float_runtime_quant.py 
b/src/brevitas/proxy/groupwise_float_runtime_quant.py index b2aad4729..25d496192 100644 --- a/src/brevitas/proxy/groupwise_float_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_float_runtime_quant.py @@ -1,15 +1,14 @@ -from typing import Union - -from torch import Tensor - from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase from brevitas.quant_tensor import GroupwiseFloatQuantTensor -from brevitas.quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIOGroupwiseFloat class GroupwiseActFloatQuantProxyFromInjector(ActFloatQuantProxyFromInjectorBase): + def __init__(self, quant_layer, quant_injector): + super().__init__(quant_layer, quant_injector) + self.cache_class = _CachedIOGroupwiseFloat + @property def group_dim(self): return self.quant_injector.group_dim @@ -18,67 +17,36 @@ def group_dim(self): def group_size(self): return self.quant_injector.group_size - def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, GroupwiseFloatQuantTensor]: - out = x - if self.fused_activation_quant_proxy is not None: - y = x - if isinstance(y, QuantTensor): - y = y.value - if self.export_mode: - y = self.fused_activation_quant_proxy.activation_impl(y) - y = self.export_handler(y) - elif not self.is_quant_enabled: - y = self.fused_activation_quant_proxy.activation_impl(y) - else: - y = self.fused_activation_quant_proxy(y) - # If y is an empty GroupwiseFloatQuantTensor, we need to check if this is a passthrough proxy, - # otherwise return a simple Tensor - # We exclude the last two values (inf_values and nan_values) - if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])): - value, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = y - out = GroupwiseFloatQuantTensor( - value, - scale, - zero_point, - self.group_size, - self.group_dim, - exponent_bit_width, - mantissa_bit_width, - exponent_bias, - saturating, - inf_values, - nan_values, - signed=self.is_signed, - training=self.training) - elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant - if isinstance(y, tuple): - y = y[0] - if isinstance(x, GroupwiseFloatQuantTensor): - out = GroupwiseFloatQuantTensor( - y, - x.scale, - x.zero_point, - self.group_size, - self.group_dim, - x.exponent_bit_width, - x.mantissa_bit_width, - x.exponent_bias, - x.saturating, - x.inf_values, - x.nan_values, - x.signed, - self.training) - else: - out = y - else: - if isinstance(y, tuple): - y = y[0] - out = y + def create_quant_tensor(self, qt_args, x=None): + if x is None: + value, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args + out = GroupwiseFloatQuantTensor( + value, + scale, + zero_point, + self.group_size, + self.group_dim, + exponent_bit_width, + mantissa_bit_width, + exponent_bias, + saturating, + inf_values, + nan_values, + self.is_signed, + self.training) else: - # If fused activation quant proxy is not enabled, return the input - out = x - if not self.training and self.cache_inference_quant_act and isinstance( - out, GroupwiseFloatQuantTensor): - cached_out = _CachedIOGroupwiseFloat(out.detach(), self.cache_quant_io_metadata_only) - self._cached_act = cached_out + out = GroupwiseFloatQuantTensor( + qt_args, + x.scale, + x.zero_point, + self.group_size, + self.group_dim, + x.exponent_bit_width, + x.mantissa_bit_width, + x.exponent_bias, + x.saturating, + x.inf_values, + x.nan_values, + x.signed, + self.training) return out 
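The parameter- and runtime-proxy diffs above all collapse a near-identical forward() into a per-class create_quant_tensor() hook; the shared forward itself lands in src/brevitas/proxy/runtime_quant.py further down in this patch. Below is a minimal, self-contained sketch of that control flow. ToyQuantTensor, ToyActProxy and quant_impl are simplified stand-ins invented for illustration, not Brevitas classes.

# Hedged sketch (editor's illustration, not part of the patch): the control flow
# that the shared ActQuantProxyFromInjectorBase.forward() implements once each
# subclass only supplies create_quant_tensor().
from typing import Any, Optional, Tuple, Union

import torch


class ToyQuantTensor:
    """Stand-in for IntQuantTensor/FloatQuantTensor: a value plus quant metadata."""

    def __init__(self, value: torch.Tensor, scale: torch.Tensor):
        self.value, self.scale = value, scale


class ToyActProxy:
    is_quant_enabled = True
    is_passthrough_act = True

    def quant_impl(self, x: torch.Tensor) -> Tuple[Any, ...]:
        # Stand-in for fused_activation_quant_proxy: returns (value, scale, ...).
        scale = x.abs().max() / 127.0
        return torch.round(x / scale) * scale, scale

    def create_quant_tensor(self, qt_args, x: Optional[ToyQuantTensor] = None):
        # The only per-subclass piece: build the concrete QuantTensor flavor.
        if x is None:
            return ToyQuantTensor(*qt_args)  # fresh quantization: unpack the tuple
        return ToyQuantTensor(qt_args, x.scale)  # passthrough: reuse input metadata

    def forward(self, x: Union[torch.Tensor, ToyQuantTensor]):
        y = x.value if isinstance(x, ToyQuantTensor) else x
        # Quant disabled still yields a tuple so the branches below stay uniform.
        y = self.quant_impl(y) if self.is_quant_enabled else (y, None)
        if y[1] is not None:  # scale present -> quantization actually ran
            return self.create_quant_tensor(y)
        if self.is_passthrough_act and isinstance(x, ToyQuantTensor):
            return self.create_quant_tensor(y[0], x=x)
        return y[0]


proxy = ToyActProxy()
qt = proxy.forward(torch.randn(8))
print(type(qt).__name__, float(qt.scale))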
diff --git a/src/brevitas/proxy/groupwise_int_parameter_quant.py b/src/brevitas/proxy/groupwise_int_parameter_quant.py index 35892daeb..f8f1a266c 100644 --- a/src/brevitas/proxy/groupwise_int_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_int_parameter_quant.py @@ -1,19 +1,16 @@ -from typing import Optional, Union +from typing import Any, List, Optional, Union import torch from torch import Tensor from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector +from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import GroupwiseIntQuantTensor +from brevitas.utils.quant_utils import _CachedIOGroupwiseInt class GroupwiseWeightQuantProxyFromInjector(WeightQuantProxyFromInjector): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # TODO: Is this always generated? - self.view_impl = self.quant_injector.scaling_stats_input_view_shape_impl - @property def group_dim(self): return self.quant_injector.group_dim @@ -22,18 +19,14 @@ def group_dim(self): def group_size(self): return self.quant_injector.group_size - def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseIntQuantTensor]: - if self.is_quant_enabled: - impl = self.export_handler if self.export_mode else self.tensor_quant - out, scale, zero_point, bit_width = impl(x) - return GroupwiseIntQuantTensor( - out, - scale, - zero_point, - self.group_size, - self.group_dim, - bit_width, - self.is_signed, - self.training) - else: # quantization disabled - return x + def create_quant_tensor(self, qt_args: List[Any]) -> Union[Tensor, GroupwiseIntQuantTensor]: + out, scale, zero_point, bit_width = qt_args + return GroupwiseIntQuantTensor( + out, + scale, + zero_point, + self.group_size, + self.group_dim, + bit_width, + self.is_signed, + self.training) diff --git a/src/brevitas/proxy/groupwise_int_runtime_quant.py b/src/brevitas/proxy/groupwise_int_runtime_quant.py index e9788e89b..566148d81 100644 --- a/src/brevitas/proxy/groupwise_int_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_int_runtime_quant.py @@ -1,5 +1,6 @@ from typing import Union +import torch from torch import Tensor from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector @@ -10,6 +11,10 @@ class GroupwiseActQuantProxyFromInjector(ActQuantProxyFromInjector): + def __init__(self, quant_layer, quant_injector): + super().__init__(quant_layer, quant_injector) + self.cache_class = _CachedIOGroupwiseInt + @property def group_dim(self): return self.quant_injector.group_dim @@ -18,58 +23,26 @@ def group_dim(self): def group_size(self): return self.quant_injector.group_size - def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, GroupwiseIntQuantTensor]: - out = x - if self.fused_activation_quant_proxy is not None: - y = x - if isinstance(y, QuantTensor): - y = y.value - - if self.export_mode: - y = self.fused_activation_quant_proxy.activation_impl(y) - y = self.export_handler(y) - elif not self.is_quant_enabled: - y = self.fused_activation_quant_proxy.activation_impl(y) - else: - y = self.fused_activation_quant_proxy(y) - # If y is an empty GroupwiseIntQuantTensor, we need to check if this is a passthrough proxy, - # otherwise return a simple Tensor - # We exclude the last two values (inf_values and nan_values) - if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])): - value, scale, zero_point, bit_width, = y - out = GroupwiseIntQuantTensor( - value, - scale, - zero_point, - self.group_size, - self.group_dim, - bit_width, - signed=self.is_signed, - training=self.training) - 
elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant - if isinstance(y, tuple): - y = y[0] - if isinstance(x, GroupwiseIntQuantTensor): - out = GroupwiseIntQuantTensor( - y, - x.scale, - x.zero_point, - self.group_size, - self.group_dim, - x.bit_width, - x.signed, - self.training) - else: - out = y - else: - if isinstance(y, tuple): - y = y[0] - out = y + def create_quant_tensor(self, qt_args, x=None): + if x is None: + value, scale, zero_point, bit_width, = qt_args + out = GroupwiseIntQuantTensor( + value, + scale, + zero_point, + self.group_size, + self.group_dim, + bit_width, + self.is_signed, + self.training) else: - # If fused activation quant proxy is not enabled, return the input - out = x - if not self.training and self.cache_inference_quant_act and isinstance( - out, GroupwiseIntQuantTensor): - cached_out = _CachedIOGroupwiseInt(out.detach(), self.cache_quant_io_metadata_only) - self._cached_act = cached_out + out = GroupwiseIntQuantTensor( + qt_args, + x.scale, + x.zero_point, + self.group_size, + self.group_dim, + x.bit_width, + x.signed, + self.training) return out diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index fc4e75cb9..4b9fa8195 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -4,7 +4,7 @@ from abc import ABC from abc import ABCMeta from abc import abstractmethod -from typing import Optional, Union +from typing import Any, Optional, Tuple, Union from warnings import warn import torch @@ -84,6 +84,24 @@ class WeightQuantProxyFromInjectorBase(ParameterQuantProxyFromInjector, WeightQuantProxyProtocol, ABC): + def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: + super().__init__(quant_layer, quant_injector) + self._cached_weight = None + self._cache_inference_quant_weight = False + self.cache_inference_quant_weight_metadata_only = False + self.cache_class = None # To be redefined by each class + self.quant_tensor_class = None # To be redefined by each class + + @property + def cache_inference_quant_weight(self): + return self._cache_inference_quant_weight + + @cache_inference_quant_weight.setter + def cache_inference_quant_weight(self, value): + if not value: + self._cached_weight = None + self._cache_inference_quant_weight = value + @property def tracked_parameter_list(self): return [m.weight for m in self.tracked_module_list if m.weight is not None] @@ -92,6 +110,27 @@ def tracked_parameter_list(self): def requires_quant_input(self): return False + @abstractmethod + def create_quant_tensor(self, qt_args: Tuple[Any]) -> Union[Tensor, QuantTensor]: + raise NotImplementedError + + def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: + if self.is_quant_enabled: + if self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only: + out = self._cached_weight.quant_tensor + else: + impl = self.export_handler if self.export_mode else self.tensor_quant + out = impl(x) + out = self.create_quant_tensor(out) + if not self.training and self.cache_inference_quant_weight and self._cached_weight is None: + self._cached_weight = self.cache_class( + out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only) + else: + out = out[0] + else: # quantization disabled + out = x + return out + class BiasQuantProxyFromInjectorBase(ParameterQuantProxyFromInjector, BiasQuantProxyProtocol, ABC): @@ -99,18 +138,13 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: 
super().__init__(quant_layer, quant_injector) self._cached_bias = None self.cache_inference_quant_bias = False + self.cache_inference_quant_bias_metadata_only = False + self.requires_input_scale = self.quant_injector.requires_input_scale @property def tracked_parameter_list(self): return [m.bias for m in self.tracked_module_list if m.bias is not None] - @property - def requires_input_scale(self) -> bool: - if self.is_quant_enabled: - return self.quant_injector.requires_input_scale - else: - return False - def get_cached(self, attr): if self._cached_bias is None: warn( @@ -126,8 +160,7 @@ class WeightQuantProxyFromInjector(WeightQuantProxyFromInjectorBase): def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: super().__init__(quant_layer, quant_injector) - self._cached_weight = None - self.cache_inference_quant_weight = False + self.cache_class = _CachedIO @property def tracked_parameter_list(self): @@ -155,22 +188,8 @@ def bit_width(self): bit_width = self.__call__(self.tracked_parameter_list[0]).bit_width return bit_width - def forward(self, x: torch.Tensor) -> Union[Tensor, IntQuantTensor]: - if self.is_quant_enabled: - if self._cached_weight is not None: - out = self._cached_weight.quant_tensor - else: - impl = self.export_handler if self.export_mode else self.tensor_quant - out, scale, zero_point, bit_width = impl(x) - out = IntQuantTensor( - out, scale, zero_point, bit_width, self.is_signed, self.training) - else: # quantization disabled - out = x - if isinstance( - out, IntQuantTensor - ) and not self.training and self.cache_inference_quant_weight and self._cached_weight is None: - self._cached_weight = _CachedIO(out.detach(), metadata_only=False) - return out + def create_quant_tensor(self, qt_args: Tuple[Any]) -> Union[Tensor, QuantTensor]: + return IntQuantTensor(*qt_args, self.is_signed, self.training) class DecoupledWeightQuantProxyFromInjector(WeightQuantProxyFromInjector): @@ -189,13 +208,9 @@ def pre_zero_point(self): out, scale, zero_point, bit_width, pre_scale, pre_zero_point = output_tuple return pre_zero_point - def forward(self, x: torch.Tensor) -> Union[Tensor, IntQuantTensor]: - if self.is_quant_enabled: - impl = self.export_handler if self.export_mode else self.tensor_quant - out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x) - return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) - else: # quantization disabled - return x + def create_quant_tensor(self, qt_args: Tuple[Any]) -> Union[Tensor, QuantTensor]: + out, scale, zero_point, bit_width, pre_scale, pre_zero_point = qt_args + return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) class DecoupledWeightQuantWithInputProxyFromInjector(DecoupledWeightQuantProxyFromInjector): @@ -249,7 +264,7 @@ class BiasQuantProxyFromInjector(BiasQuantProxyFromInjectorBase): def scale(self): if not self.is_quant_enabled: return None - if self.requires_input_scale: + if self.requires_input_scale and self.is_quant_enabled and self.is_quant_enabled: cache = self.get_cached('scale') return cache zhs = self._zero_hw_sentinel() @@ -282,7 +297,7 @@ def compute_bias_scale( self, input: Optional[Union[Tensor, IntQuantTensor]], weight: Optional[Union[Tensor, IntQuantTensor]]) -> Optional[Tensor]: - if not self.requires_input_scale: + if not self.requires_input_scale and self.is_quant_enabled: return None if not isinstance(input, IntQuantTensor) or not isinstance(weight, IntQuantTensor): return None @@ -305,23 +320,23 @@ def forward( 
input_scale = self.compute_bias_scale(input, weight) if self.is_quant_enabled: impl = self.export_handler if self.export_mode else self.tensor_quant - if self.requires_input_scale and input_scale is None: + if self.requires_input_scale and input_scale is None and self.is_quant_enabled: input_scale = self.scale() if input_scale is None: raise RuntimeError("Input scale required") - - if self.requires_input_scale: + elif self.requires_input_scale and input_scale is not None and self.is_quant_enabled: input_scale = input_scale.view(-1) + + if self.requires_input_scale and self.is_quant_enabled: out, out_scale, out_zp, out_bit_width = impl(x, input_scale) else: out, out_scale, out_zp, out_bit_width = impl(x) - out = IntQuantTensor( out, out_scale, out_zp, out_bit_width, self.is_signed, self.training) + if not self.training and self.cache_inference_quant_bias: + cached_bias = _CachedIO( + out.detach(), metadata_only=self.cache_inference_quant_bias_metadata_only) + self._cached_bias = cached_bias else: out = x - if isinstance(out, - IntQuantTensor) and not self.training and self.cache_inference_quant_bias: - cached_bias = _CachedIO(out.detach(), metadata_only=False) - self._cached_bias = cached_bias return out diff --git a/src/brevitas/proxy/quant_proxy.py b/src/brevitas/proxy/quant_proxy.py index 3a680035e..1b847b280 100644 --- a/src/brevitas/proxy/quant_proxy.py +++ b/src/brevitas/proxy/quant_proxy.py @@ -28,12 +28,6 @@ def _is_groupwise(quant_injector): return False -def _is_signed(quant_injector): - if 'signed' in quant_injector: - return quant_injector.signed - return None - - def _is_narrow_range(quant_injector): if 'narrow_range' in quant_injector: return quant_injector.narrow_range @@ -88,6 +82,8 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: self.tracked_module_list = [] self.add_tracked_module(quant_layer) self.disable_quant = False + # Torch.compile compatibility requires this + self.is_signed = quant_injector.signed if 'signed' in quant_injector else None @property def requires_export_handler(self): @@ -108,10 +104,6 @@ def init_tensor_quant(self): def is_quant_enabled(self): return not self.disable_quant and self.tensor_quant is not None - @property - def is_signed(self): - return _is_signed(self.quant_injector) - @property def is_groupwise(self): return _is_groupwise(self.quant_injector) diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index a89bc9abb..5d69e67e5 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause from abc import ABC +from abc import abstractmethod from typing import Optional, Tuple, Union from torch import nn @@ -94,6 +95,7 @@ def __init__(self, quant_layer, quant_injector): self._cached_act = None self.cache_inference_quant_act = False self.cache_quant_io_metadata_only = True + self.cache_class = None def internal_forward(self, force_eval): current_status = self.training @@ -116,12 +118,6 @@ def retrieve_attribute(self, attribute, force_eval): def is_quant_enabled(self): return self._is_quant_enabled and not self.disable_quant - @property - def is_signed(self): - if self._cached_act is not None: - return self._cached_act.signed - return super().is_signed - @is_quant_enabled.setter def is_quant_enabled(self, is_quant_enabled): self._is_quant_enabled = is_quant_enabled @@ -145,9 +141,53 @@ def init_tensor_quant(self): else: self.fused_activation_quant_proxy = None + @abstractmethod + def 
create_quant_tensor(self, qt_args, x=None): + raise NotImplementedError + + def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: + # If fused activation quant proxy is not enabled, return the input + if self.fused_activation_quant_proxy is None: + return x + + y = x + if isinstance(y, QuantTensor): + y = y.value + + if self.export_mode: + y = self.fused_activation_quant_proxy.activation_impl(y) + y = self.export_handler(y) + elif not self.is_quant_enabled: + # A tuple helps later with control flows + # The second None value is used later + y = (self.fused_activation_quant_proxy.activation_impl(y), None) + else: + y = self.fused_activation_quant_proxy(y) + # If y is an empty IntQuantTensor, we need to check if this is a passthrough proxy, + # otherwise return a simple Tensor + + # If the second value (i.e., scale) is None, then quant is disabled + if isinstance(y, tuple) and y[1] is not None: + out = self.create_quant_tensor(y) + elif self.is_passthrough_act and isinstance(x, QuantTensor): + # preserve scale/zp/bit/sign even without output quant + y = y[0] + out = self.create_quant_tensor(y, x=x) + else: + out = y[0] + + if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor): + cached_out = self.cache_class(out.detach(), self.cache_quant_io_metadata_only) + self._cached_act = cached_out + return out + class ActQuantProxyFromInjector(ActQuantProxyFromInjectorBase): + def __init__(self, quant_layer, quant_injector): + super().__init__(quant_layer, quant_injector) + self.cache_class = _CachedIO + def scale(self, force_eval=True): return self.retrieve_attribute('scale', force_eval) @@ -157,42 +197,12 @@ def zero_point(self, force_eval=True): def bit_width(self, force_eval=True): return self.retrieve_attribute('bit_width', force_eval) - def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, IntQuantTensor]: - out = x - if self.fused_activation_quant_proxy is not None: - y = x - if isinstance(y, QuantTensor): - y = y.value - - if self.export_mode: - y = self.fused_activation_quant_proxy.activation_impl(y) - y = self.export_handler(y) - elif not self.is_quant_enabled: - y = self.fused_activation_quant_proxy.activation_impl(y) - else: - y = self.fused_activation_quant_proxy(y) - # If y is an empty IntQuantTensor, we need to check if this is a passthrough proxy, - # otherwise return a simple Tensor - if isinstance(y, tuple) and not any(map(lambda f: f is None, y)): - out = IntQuantTensor(*y, signed=self.is_signed, training=self.training) - elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant - if isinstance(y, tuple): - y = y[0] - if isinstance(x, IntQuantTensor): - out = IntQuantTensor( - y, x.scale, x.zero_point, x.bit_width, x.signed, self.training) - else: - out = y - else: - if isinstance(y, tuple): - y = y[0] - out = y + def create_quant_tensor(self, qt_args, x=None): + if x is None: + out = IntQuantTensor(*qt_args, self.is_signed, self.training) else: - # If fused activation quant proxy is not enabled, return the input - out = x - if not self.training and self.cache_inference_quant_act and isinstance(out, IntQuantTensor): - cached_out = _CachedIO(out.detach(), self.cache_quant_io_metadata_only) - self._cached_act = cached_out + out = IntQuantTensor( + qt_args, x.scale, x.zero_point, x.bit_width, x.signed, self.training) return out diff --git a/src/brevitas/utils/quant_utils.py b/src/brevitas/utils/quant_utils.py index f160877a0..6fd519b41 100644 --- a/src/brevitas/utils/quant_utils.py +++ 
b/src/brevitas/utils/quant_utils.py @@ -16,13 +16,14 @@ class _CachedIO: def __init__(self, quant_tensor: IntQuantTensor, metadata_only: bool): self.shape = quant_tensor.value.shape if metadata_only: + self.value = None self.quant_tensor = quant_tensor.set(value=None) else: self.quant_tensor = quant_tensor - - @property - def scale(self): - return self.quant_tensor.scale + # torch.compile compatibility + self.value = quant_tensor.value + # torch.compile compatibility + self.scale = quant_tensor.scale @property def zero_point(self): @@ -42,13 +43,14 @@ class _CachedIOFloat: def __init__(self, quant_tensor: FloatQuantTensor, metadata_only: bool): self.shape = quant_tensor.value.shape if metadata_only: + self.value = None self.quant_tensor = quant_tensor.set(value=None) else: self.quant_tensor = quant_tensor - - @property - def scale(self): - return self.quant_tensor.scale + # torch.compile compatibility + self.value = quant_tensor.value + # torch.compile compatibility + self.scale = quant_tensor.scale @property def zero_point(self): @@ -88,13 +90,14 @@ class _CachedIOGroupwiseFloat: def __init__(self, quant_tensor: GroupwiseFloatQuantTensor, metadata_only: bool): self.shape = quant_tensor.value.shape if metadata_only: + self.value = None self.quant_tensor = quant_tensor.set(value=None) else: self.quant_tensor = quant_tensor - - @property - def scale(self): - return self.quant_tensor.scale + # torch.compile compatibility + self.value = quant_tensor.value + # torch.compile compatibility + self.scale = quant_tensor.scale @property def zero_point(self): @@ -142,13 +145,14 @@ class _CachedIOGroupwiseInt: def __init__(self, quant_tensor: GroupwiseIntQuantTensor, metadata_only: bool): self.shape = quant_tensor.value.shape if metadata_only: + self.value = None self.quant_tensor = quant_tensor.set(value=None) else: self.quant_tensor = quant_tensor - - @property - def scale(self): - return self.quant_tensor.scale + # torch.compile compatibility + self.value = quant_tensor.value + # torch.compile compatibility + self.scale = quant_tensor.scale @property def zero_point(self): diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 7e2bf6ee5..8fd2f655d 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -18,6 +18,7 @@ from brevitas.export import export_onnx_qcdq from brevitas.export import export_torch_qcdq +from brevitas.graph.calibrate import inference_mode from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.quantize import preprocess_for_quantize from brevitas.graph.target.flexml import preprocess_for_flexml_quantize @@ -365,6 +366,7 @@ def main(): # Get the model from torchvision model = get_torchvision_model(args.model_name) model = model.to(dtype) + model.eval() # Preprocess the model for quantization if args.target_backend == 'flexml': From 8111ae0ee568c27296cba5f4e666d856db0d3254 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 28 Aug 2024 10:16:11 +0100 Subject: [PATCH 2/7] cleanup --- src/brevitas/proxy/float_parameter_quant.py | 9 ++------- src/brevitas/proxy/float_runtime_quant.py | 12 +++--------- .../proxy/groupwise_float_parameter_quant.py | 6 ++---- src/brevitas/proxy/groupwise_float_runtime_quant.py | 7 ++++++- src/brevitas/proxy/groupwise_int_parameter_quant.py | 9 ++------- src/brevitas/proxy/groupwise_int_runtime_quant.py | 11 +++++------ 
src/brevitas/proxy/parameter_quant.py | 4 ++-- src/brevitas/proxy/quant_proxy.py | 1 - src/brevitas/proxy/runtime_quant.py | 6 ++++-- 9 files changed, 26 insertions(+), 39 deletions(-) diff --git a/src/brevitas/proxy/float_parameter_quant.py b/src/brevitas/proxy/float_parameter_quant.py index 9bad29e79..5d1920a8d 100644 --- a/src/brevitas/proxy/float_parameter_quant.py +++ b/src/brevitas/proxy/float_parameter_quant.py @@ -1,15 +1,10 @@ -from typing import Any, List, Optional, Union +from typing import Any, Optional, Tuple, Union -import torch from torch import Tensor -import torch.nn as nn -from brevitas.inject import BaseInjector as Injector from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjectorBase from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase -from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import FloatQuantTensor -from brevitas.quant_tensor.base_quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIOFloat @@ -88,7 +83,7 @@ def is_fnuz(self): class WeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase): - def create_quant_tensor(self, qt_args: List[Any]) -> Union[Tensor, QuantTensor]: + def create_quant_tensor(self, qt_args: Tuple[Any]) -> FloatQuantTensor: out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args return FloatQuantTensor( out, diff --git a/src/brevitas/proxy/float_runtime_quant.py b/src/brevitas/proxy/float_runtime_quant.py index eddfb0e17..6119cba03 100644 --- a/src/brevitas/proxy/float_runtime_quant.py +++ b/src/brevitas/proxy/float_runtime_quant.py @@ -1,14 +1,7 @@ -from typing import Optional, Union -from warnings import warn +from typing import Any, Optional, Tuple -import torch -from torch import Tensor -import torch.nn as nn - -from brevitas.inject import BaseInjector as Injector from brevitas.proxy.runtime_quant import ActQuantProxyFromInjectorBase from brevitas.quant_tensor import FloatQuantTensor -from brevitas.quant_tensor.base_quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIOFloat @@ -67,7 +60,8 @@ def __init__(self, quant_layer, quant_injector): super().__init__(quant_layer, quant_injector) self.cache_class = _CachedIOFloat - def create_quant_tensor(self, qt_args, x=None): + def create_quant_tensor( + self, qt_args: Tuple[Any], x: Optional[FloatQuantTensor] = None) -> FloatQuantTensor: if x is None: out = FloatQuantTensor(*qt_args, signed=self.is_signed, training=self.training) else: diff --git a/src/brevitas/proxy/groupwise_float_parameter_quant.py b/src/brevitas/proxy/groupwise_float_parameter_quant.py index cbd91e463..7dab9bb93 100644 --- a/src/brevitas/proxy/groupwise_float_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_float_parameter_quant.py @@ -1,6 +1,4 @@ -from typing import Any, List, Union - -from torch import Tensor +from typing import Any, Tuple from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjectorBase from brevitas.quant_tensor import GroupwiseFloatQuantTensor @@ -16,7 +14,7 @@ def group_dim(self): def group_size(self): return self.quant_injector.group_size - def create_quant_tensor(self, qt_args: List[Any]) -> Union[Tensor, GroupwiseFloatQuantTensor]: + def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseFloatQuantTensor: out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args return GroupwiseFloatQuantTensor( out, 
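Each create_quant_tensor() in these diffs relies on tensor_quant returning its results in a fixed positional order. The sketch below documents that contract for the float case; the field order is copied from the unpacking in this patch, while FloatQTArgs itself is an editor's illustration, not a Brevitas type. The groupwise variants follow the same order, splicing self.group_size and self.group_dim in after zero_point when constructing the tensor.

# Hedged sketch: the positional contract behind
# "out, scale, zero_point, ... = qt_args" in the float proxies above.
from typing import Any, NamedTuple, Optional, Tuple

import torch


class FloatQTArgs(NamedTuple):
    value: torch.Tensor
    scale: torch.Tensor
    zero_point: torch.Tensor
    exponent_bit_width: torch.Tensor
    mantissa_bit_width: torch.Tensor
    exponent_bias: torch.Tensor
    saturating: bool
    inf_values: Optional[Tuple[str, ...]]
    nan_values: Optional[Tuple[str, ...]]


def describe(qt_args: Tuple[Any, ...]) -> str:
    # Same positional unpacking as create_quant_tensor, made self-documenting.
    a = FloatQTArgs(*qt_args)
    return (f"e{int(a.exponent_bit_width)}m{int(a.mantissa_bit_width)} "
            f"bias={int(a.exponent_bias)} saturating={a.saturating}")


print(describe((torch.zeros(1), torch.ones(1), torch.zeros(1),
                torch.tensor(4.), torch.tensor(3.), torch.tensor(7.),
                True, None, None)))  # -> "e4m3 bias=7 saturating=True"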
diff --git a/src/brevitas/proxy/groupwise_float_runtime_quant.py b/src/brevitas/proxy/groupwise_float_runtime_quant.py index 25d496192..7d7828ea2 100644 --- a/src/brevitas/proxy/groupwise_float_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_float_runtime_quant.py @@ -1,3 +1,5 @@ +from typing import Any, Optional, Tuple + from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase from brevitas.quant_tensor import GroupwiseFloatQuantTensor from brevitas.utils.quant_utils import _CachedIOGroupwiseFloat @@ -17,7 +19,10 @@ def group_dim(self): def group_size(self): return self.quant_injector.group_size - def create_quant_tensor(self, qt_args, x=None): + def create_quant_tensor( + self, + qt_args: Tuple[Any], + x: Optional[GroupwiseFloatQuantTensor] = None) -> GroupwiseFloatQuantTensor: if x is None: value, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args out = GroupwiseFloatQuantTensor( diff --git a/src/brevitas/proxy/groupwise_int_parameter_quant.py b/src/brevitas/proxy/groupwise_int_parameter_quant.py index f8f1a266c..d0cd10334 100644 --- a/src/brevitas/proxy/groupwise_int_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_int_parameter_quant.py @@ -1,12 +1,7 @@ -from typing import Any, List, Optional, Union - -import torch -from torch import Tensor +from typing import Any, List from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector -from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import GroupwiseIntQuantTensor -from brevitas.utils.quant_utils import _CachedIOGroupwiseInt class GroupwiseWeightQuantProxyFromInjector(WeightQuantProxyFromInjector): @@ -19,7 +14,7 @@ def group_dim(self): def group_size(self): return self.quant_injector.group_size - def create_quant_tensor(self, qt_args: List[Any]) -> Union[Tensor, GroupwiseIntQuantTensor]: + def create_quant_tensor(self, qt_args: List[Any]) -> GroupwiseIntQuantTensor: out, scale, zero_point, bit_width = qt_args return GroupwiseIntQuantTensor( out, diff --git a/src/brevitas/proxy/groupwise_int_runtime_quant.py b/src/brevitas/proxy/groupwise_int_runtime_quant.py index 566148d81..34ab22619 100644 --- a/src/brevitas/proxy/groupwise_int_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_int_runtime_quant.py @@ -1,11 +1,7 @@ -from typing import Union - -import torch -from torch import Tensor +from typing import Any, Optional, Tuple from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector from brevitas.quant_tensor import GroupwiseIntQuantTensor -from brevitas.quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIOGroupwiseInt @@ -23,7 +19,10 @@ def group_dim(self): def group_size(self): return self.quant_injector.group_size - def create_quant_tensor(self, qt_args, x=None): + def create_quant_tensor( + self, + qt_args: Tuple[Any], + x: Optional[GroupwiseIntQuantTensor] = None) -> GroupwiseIntQuantTensor: if x is None: value, scale, zero_point, bit_width, = qt_args out = GroupwiseIntQuantTensor( diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 4b9fa8195..604a43c00 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -188,7 +188,7 @@ def bit_width(self): bit_width = self.__call__(self.tracked_parameter_list[0]).bit_width return bit_width - def create_quant_tensor(self, qt_args: Tuple[Any]) -> Union[Tensor, QuantTensor]: + def create_quant_tensor(self, qt_args: Tuple[Any]) -> 
IntQuantTensor: return IntQuantTensor(*qt_args, self.is_signed, self.training) @@ -208,7 +208,7 @@ def pre_zero_point(self): out, scale, zero_point, bit_width, pre_scale, pre_zero_point = output_tuple return pre_zero_point - def create_quant_tensor(self, qt_args: Tuple[Any]) -> Union[Tensor, QuantTensor]: + def create_quant_tensor(self, qt_args: Tuple[Any]) -> IntQuantTensor: out, scale, zero_point, bit_width, pre_scale, pre_zero_point = qt_args return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) diff --git a/src/brevitas/proxy/quant_proxy.py b/src/brevitas/proxy/quant_proxy.py index 1b847b280..2d2ed10c1 100644 --- a/src/brevitas/proxy/quant_proxy.py +++ b/src/brevitas/proxy/quant_proxy.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: BSD-3-Clause from abc import ABCMeta -from abc import abstractmethod from typing import Optional from torch import nn diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 5d69e67e5..14dfdb564 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -3,7 +3,7 @@ from abc import ABC from abc import abstractmethod -from typing import Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union from torch import nn from torch import Tensor @@ -197,7 +197,9 @@ def zero_point(self, force_eval=True): def bit_width(self, force_eval=True): return self.retrieve_attribute('bit_width', force_eval) - def create_quant_tensor(self, qt_args, x=None): + def create_quant_tensor( + self, qt_args: Tuple[Any], x: Optional[IntQuantTensor] = None) -> IntQuantTensor: + if x is None: out = IntQuantTensor(*qt_args, self.is_signed, self.training) else: From bdc316fd89de91936951cb15472a01d620c13467 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 28 Aug 2024 10:19:02 +0100 Subject: [PATCH 3/7] fix ptq script --- .../imagenet_classification/ptq/ptq_evaluate.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 8fd2f655d..3a9bb29fa 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -18,8 +18,6 @@ from brevitas.export import export_onnx_qcdq from brevitas.export import export_torch_qcdq -from brevitas.graph.calibrate import inference_mode -from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.quantize import preprocess_for_quantize from brevitas.graph.target.flexml import preprocess_for_flexml_quantize from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization From 12bf43f99b71085c71cd2a879f9e972b53114f87 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 28 Aug 2024 10:41:41 +0100 Subject: [PATCH 4/7] fix quant proxy --- src/brevitas/proxy/parameter_quant.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 604a43c00..1f9cc62b6 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -125,8 +125,6 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: if not self.training and self.cache_inference_quant_weight and self._cached_weight is None: self._cached_weight = self.cache_class( out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only) - else: - out = out[0] else: # quantization disabled out = x 
return out From 4162ef3bfda8aaaa055eeeb2e9f5044282953e74 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 31 Aug 2024 10:54:06 +0100 Subject: [PATCH 5/7] review --- src/brevitas/proxy/float_parameter_quant.py | 9 ++++++- src/brevitas/proxy/float_runtime_quant.py | 15 ++++++++--- .../proxy/groupwise_float_parameter_quant.py | 8 ++++++ .../proxy/groupwise_float_runtime_quant.py | 6 +++-- .../proxy/groupwise_int_parameter_quant.py | 12 +++++++-- .../proxy/groupwise_int_runtime_quant.py | 6 +++-- src/brevitas/proxy/parameter_quant.py | 12 ++++++--- src/brevitas/proxy/runtime_quant.py | 25 +++++++++++++------ 8 files changed, 72 insertions(+), 21 deletions(-) diff --git a/src/brevitas/proxy/float_parameter_quant.py b/src/brevitas/proxy/float_parameter_quant.py index 5d1920a8d..4e6452792 100644 --- a/src/brevitas/proxy/float_parameter_quant.py +++ b/src/brevitas/proxy/float_parameter_quant.py @@ -1,14 +1,17 @@ +from abc import ABC from typing import Any, Optional, Tuple, Union from torch import Tensor +import torch.nn as nn +from brevitas.inject import BaseInjector as Injector from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjectorBase from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase from brevitas.quant_tensor import FloatQuantTensor from brevitas.utils.quant_utils import _CachedIOFloat -class WeightFloatQuantProxyFromInjectorBase(WeightQuantProxyFromInjectorBase): +class WeightFloatQuantProxyFromInjectorBase(WeightQuantProxyFromInjectorBase, ABC): def scale(self): if not self.is_quant_enabled: @@ -83,6 +86,10 @@ def is_fnuz(self): class WeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase): + def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: + super().__init__(quant_layer, quant_injector) + self.cache_class = _CachedIOFloat + def create_quant_tensor(self, qt_args: Tuple[Any]) -> FloatQuantTensor: out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args return FloatQuantTensor( diff --git a/src/brevitas/proxy/float_runtime_quant.py b/src/brevitas/proxy/float_runtime_quant.py index 6119cba03..b38f4ecdb 100644 --- a/src/brevitas/proxy/float_runtime_quant.py +++ b/src/brevitas/proxy/float_runtime_quant.py @@ -1,11 +1,16 @@ -from typing import Any, Optional, Tuple +from abc import ABC +from typing import Any, Optional, Tuple, Union +import torch +import torch.nn as nn + +from brevitas.inject import BaseInjector as Injector from brevitas.proxy.runtime_quant import ActQuantProxyFromInjectorBase from brevitas.quant_tensor import FloatQuantTensor from brevitas.utils.quant_utils import _CachedIOFloat -class ActFloatQuantProxyFromInjectorBase(ActQuantProxyFromInjectorBase): +class ActFloatQuantProxyFromInjectorBase(ActQuantProxyFromInjectorBase, ABC): def scale(self, force_eval=True): return self.retrieve_attribute('scale', force_eval) @@ -56,12 +61,14 @@ def is_fnuz(self): class ActFloatQuantProxyFromInjector(ActFloatQuantProxyFromInjectorBase): - def __init__(self, quant_layer, quant_injector): + def __init__(self, quant_layer: nn.Module, quant_injector: Injector): super().__init__(quant_layer, quant_injector) self.cache_class = _CachedIOFloat def create_quant_tensor( - self, qt_args: Tuple[Any], x: Optional[FloatQuantTensor] = None) -> FloatQuantTensor: + self, + qt_args: Union[torch.Tensor, Tuple[Any]], + x: Optional[FloatQuantTensor] = None) -> FloatQuantTensor: if x is None: out = FloatQuantTensor(*qt_args, signed=self.is_signed, 
training=self.training) else: diff --git a/src/brevitas/proxy/groupwise_float_parameter_quant.py b/src/brevitas/proxy/groupwise_float_parameter_quant.py index 7dab9bb93..10b80d8a6 100644 --- a/src/brevitas/proxy/groupwise_float_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_float_parameter_quant.py @@ -1,11 +1,19 @@ from typing import Any, Tuple +import torch.nn as nn + +from brevitas.inject import BaseInjector as Injector from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjectorBase from brevitas.quant_tensor import GroupwiseFloatQuantTensor +from brevitas.utils.quant_utils import _CachedIOGroupwiseFloat class GroupwiseWeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase): + def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: + super().__init__(quant_layer, quant_injector) + self.cache_class = _CachedIOGroupwiseFloat + @property def group_dim(self): return self.quant_injector.group_dim diff --git a/src/brevitas/proxy/groupwise_float_runtime_quant.py b/src/brevitas/proxy/groupwise_float_runtime_quant.py index 7d7828ea2..c98ff0eaf 100644 --- a/src/brevitas/proxy/groupwise_float_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_float_runtime_quant.py @@ -1,4 +1,6 @@ -from typing import Any, Optional, Tuple +from typing import Any, Optional, Tuple, Union + +import torch from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase from brevitas.quant_tensor import GroupwiseFloatQuantTensor @@ -21,7 +23,7 @@ def group_size(self): def create_quant_tensor( self, - qt_args: Tuple[Any], + qt_args: Union[torch.Tensor, Tuple[Any]], x: Optional[GroupwiseFloatQuantTensor] = None) -> GroupwiseFloatQuantTensor: if x is None: value, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args diff --git a/src/brevitas/proxy/groupwise_int_parameter_quant.py b/src/brevitas/proxy/groupwise_int_parameter_quant.py index d0cd10334..b4049cb55 100644 --- a/src/brevitas/proxy/groupwise_int_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_int_parameter_quant.py @@ -1,11 +1,19 @@ -from typing import Any, List +from typing import Any, Tuple +import torch.nn as nn + +from brevitas.inject import BaseInjector as Injector from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector from brevitas.quant_tensor import GroupwiseIntQuantTensor +from brevitas.utils.quant_utils import _CachedIOGroupwiseInt class GroupwiseWeightQuantProxyFromInjector(WeightQuantProxyFromInjector): + def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: + super().__init__(quant_layer, quant_injector) + self.cache_class = _CachedIOGroupwiseInt + @property def group_dim(self): return self.quant_injector.group_dim @@ -14,7 +22,7 @@ def group_dim(self): def group_size(self): return self.quant_injector.group_size - def create_quant_tensor(self, qt_args: List[Any]) -> GroupwiseIntQuantTensor: + def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseIntQuantTensor: out, scale, zero_point, bit_width = qt_args return GroupwiseIntQuantTensor( out, diff --git a/src/brevitas/proxy/groupwise_int_runtime_quant.py b/src/brevitas/proxy/groupwise_int_runtime_quant.py index 34ab22619..42e595fd0 100644 --- a/src/brevitas/proxy/groupwise_int_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_int_runtime_quant.py @@ -1,4 +1,6 @@ -from typing import Any, Optional, Tuple +from typing import Any, Optional, Tuple, Union + +import torch from 
brevitas.proxy.runtime_quant import ActQuantProxyFromInjector from brevitas.quant_tensor import GroupwiseIntQuantTensor @@ -21,7 +23,7 @@ def group_size(self): def create_quant_tensor( self, - qt_args: Tuple[Any], + qt_args: Union[torch.Tensor, Tuple[Any]], x: Optional[GroupwiseIntQuantTensor] = None) -> GroupwiseIntQuantTensor: if x is None: value, scale, zero_point, bit_width, = qt_args diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 1f9cc62b6..e7818359e 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -116,11 +116,17 @@ def create_quant_tensor(self, qt_args: Tuple[Any]) -> Union[Tensor, QuantTensor] def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: if self.is_quant_enabled: - if self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only: + # If quant is enabled the priority is: + # - export mode + # - cached weight + # - quantization flow + if self.export_mode: + out = self.export_handler(x) + out = self.create_quant_tensor(out) + elif self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only: out = self._cached_weight.quant_tensor else: - impl = self.export_handler if self.export_mode else self.tensor_quant - out = impl(x) + out = self.tensor_quant(x) out = self.create_quant_tensor(out) if not self.training and self.cache_inference_quant_weight and self._cached_weight is None: self._cached_weight = self.cache_class( diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 14dfdb564..b06778d5b 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -5,6 +5,7 @@ from abc import abstractmethod from typing import Any, Optional, Tuple, Union +import torch from torch import nn from torch import Tensor from torch.nn import Identity @@ -142,7 +143,14 @@ def init_tensor_quant(self): self.fused_activation_quant_proxy = None @abstractmethod - def create_quant_tensor(self, qt_args, x=None): + def create_quant_tensor( + self, + qt_args: Union[torch.Tensor, Tuple[Any]], + x: Optional[QuantTensor] = None) -> QuantTensor: + # Supports the following: + # - qt_args as tuple of Tensors and bools = standard quant activations + # - qt_args as Tensor and x as QuantTensor = passthrough activation + # In both cases, the output is a QuantTensor raise NotImplementedError def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: @@ -160,21 +168,24 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: elif not self.is_quant_enabled: # A tuple helps later with control flows # The second None value is used later - y = (self.fused_activation_quant_proxy.activation_impl(y), None) + y = self.fused_activation_quant_proxy.activation_impl(y) else: y = self.fused_activation_quant_proxy(y) - # If y is an empty IntQuantTensor, we need to check if this is a passthrough proxy, - # otherwise return a simple Tensor + # If y is an empty QuantTensor, we need to check if this is a passthrough proxy, + # otherwise return a simple Tensor # If the second value (i.e., scale) is None, then quant is disabled if isinstance(y, tuple) and y[1] is not None: out = self.create_quant_tensor(y) elif self.is_passthrough_act and isinstance(x, QuantTensor): - # preserve scale/zp/bit/sign even without output quant - y = y[0] + # preserve quant_metadata + if isinstance(y, tuple): + y = y[0] out = self.create_quant_tensor(y, x=x) else: - out = y[0] + if 
isinstance(y, tuple): + y = y[0] + out = y if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor): cached_out = self.cache_class(out.detach(), self.cache_quant_io_metadata_only) From 71c93c8ed2d17476b6eddfe7eb4877cfd7f6a318 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 3 Sep 2024 14:43:58 +0100 Subject: [PATCH 6/7] Partially revert `4162ef3bfda8aaaa055eeeb2e9f5044282953e74` -> forward function of float proxy. --- src/brevitas/proxy/runtime_quant.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index b06778d5b..0093d97c6 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -168,24 +168,21 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: elif not self.is_quant_enabled: # A tuple helps later with control flows # The second None value is used later - y = self.fused_activation_quant_proxy.activation_impl(y) + y = (self.fused_activation_quant_proxy.activation_impl(y), None) else: y = self.fused_activation_quant_proxy(y) - # If y is an empty QuantTensor, we need to check if this is a passthrough proxy, # otherwise return a simple Tensor + # If the second value (i.e., scale) is None, then quant is disabled if isinstance(y, tuple) and y[1] is not None: out = self.create_quant_tensor(y) elif self.is_passthrough_act and isinstance(x, QuantTensor): # preserve quant_metadata - if isinstance(y, tuple): - y = y[0] + y = y[0] out = self.create_quant_tensor(y, x=x) else: - if isinstance(y, tuple): - y = y[0] - out = y + out = y[0] if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor): cached_out = self.cache_class(out.detach(), self.cache_quant_io_metadata_only) From 16847f47e6cdb933926a0e11427303f7b130bed7 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 3 Sep 2024 15:00:54 +0100 Subject: [PATCH 7/7] Fix (proxy/runtime): Typo fix on type hint --- src/brevitas/proxy/runtime_quant.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 0093d97c6..4ec52e47c 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -206,7 +206,9 @@ def bit_width(self, force_eval=True): return self.retrieve_attribute('bit_width', force_eval) def create_quant_tensor( - self, qt_args: Tuple[Any], x: Optional[IntQuantTensor] = None) -> IntQuantTensor: + self, + qt_args: Union[Tensor, Tuple[Any]], + x: Optional[IntQuantTensor] = None) -> IntQuantTensor: if x is None: out = IntQuantTensor(*qt_args, self.is_signed, self.training)
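With the series applied, the caching knobs it introduces can be driven from any quant layer. A short usage sketch follows, assuming a Brevitas build that contains these patches; QuantIdentity and its act_quant proxy are standard Brevitas names, and the attribute names are the ones used in the diffs above, but exact paths may differ across versions.

# Hedged usage sketch, assuming a Brevitas build containing this series.
import torch

from brevitas.nn import QuantIdentity

act = QuantIdentity(return_quant_tensor=True)
act.eval()  # forward() only caches outside training mode (see the diffs above)
act.act_quant.cache_inference_quant_act = True
act.act_quant.cache_quant_io_metadata_only = True  # store metadata, drop the value

qt = act(torch.randn(4))
print(type(qt).__name__, float(qt.scale))      # IntQuantTensor with a per-tensor scale
print(act.act_quant._cached_act is not None)   # populated via the self.cache_class hook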