diff --git a/src/brevitas/proxy/float_parameter_quant.py b/src/brevitas/proxy/float_parameter_quant.py
index 5d1920a8d..4e6452792 100644
--- a/src/brevitas/proxy/float_parameter_quant.py
+++ b/src/brevitas/proxy/float_parameter_quant.py
@@ -1,14 +1,17 @@
+from abc import ABC
 from typing import Any, Optional, Tuple, Union
 
 from torch import Tensor
+import torch.nn as nn
+
+from brevitas.inject import BaseInjector as Injector
 from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjectorBase
 from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase
 from brevitas.quant_tensor import FloatQuantTensor
 from brevitas.utils.quant_utils import _CachedIOFloat
 
 
-class WeightFloatQuantProxyFromInjectorBase(WeightQuantProxyFromInjectorBase):
+class WeightFloatQuantProxyFromInjectorBase(WeightQuantProxyFromInjectorBase, ABC):
 
     def scale(self):
         if not self.is_quant_enabled:
@@ -83,6 +86,10 @@ def is_fnuz(self):
 
 class WeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase):
 
+    def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
+        super().__init__(quant_layer, quant_injector)
+        self.cache_class = _CachedIOFloat
+
     def create_quant_tensor(self, qt_args: Tuple[Any]) -> FloatQuantTensor:
         out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args
         return FloatQuantTensor(
diff --git a/src/brevitas/proxy/float_runtime_quant.py b/src/brevitas/proxy/float_runtime_quant.py
index 6119cba03..b38f4ecdb 100644
--- a/src/brevitas/proxy/float_runtime_quant.py
+++ b/src/brevitas/proxy/float_runtime_quant.py
@@ -1,11 +1,16 @@
-from typing import Any, Optional, Tuple
+from abc import ABC
+from typing import Any, Optional, Tuple, Union
 
+import torch
+import torch.nn as nn
+
+from brevitas.inject import BaseInjector as Injector
 from brevitas.proxy.runtime_quant import ActQuantProxyFromInjectorBase
 from brevitas.quant_tensor import FloatQuantTensor
 from brevitas.utils.quant_utils import _CachedIOFloat
 
 
-class ActFloatQuantProxyFromInjectorBase(ActQuantProxyFromInjectorBase):
+class ActFloatQuantProxyFromInjectorBase(ActQuantProxyFromInjectorBase, ABC):
 
     def scale(self, force_eval=True):
         return self.retrieve_attribute('scale', force_eval)
@@ -56,12 +61,14 @@ def is_fnuz(self):
 
 class ActFloatQuantProxyFromInjector(ActFloatQuantProxyFromInjectorBase):
 
-    def __init__(self, quant_layer, quant_injector):
+    def __init__(self, quant_layer: nn.Module, quant_injector: Injector):
         super().__init__(quant_layer, quant_injector)
         self.cache_class = _CachedIOFloat
 
     def create_quant_tensor(
-            self, qt_args: Tuple[Any], x: Optional[FloatQuantTensor] = None) -> FloatQuantTensor:
+            self,
+            qt_args: Union[torch.Tensor, Tuple[Any]],
+            x: Optional[FloatQuantTensor] = None) -> FloatQuantTensor:
         if x is None:
             out = FloatQuantTensor(*qt_args, signed=self.is_signed, training=self.training)
         else:
diff --git a/src/brevitas/proxy/groupwise_float_parameter_quant.py b/src/brevitas/proxy/groupwise_float_parameter_quant.py
index 7dab9bb93..10b80d8a6 100644
--- a/src/brevitas/proxy/groupwise_float_parameter_quant.py
+++ b/src/brevitas/proxy/groupwise_float_parameter_quant.py
@@ -1,11 +1,19 @@
 from typing import Any, Tuple
 
+import torch.nn as nn
+
+from brevitas.inject import BaseInjector as Injector
 from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjectorBase
 from brevitas.quant_tensor import GroupwiseFloatQuantTensor
+from brevitas.utils.quant_utils import _CachedIOGroupwiseFloat
 
 
 class GroupwiseWeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase):
 
+    def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
+        super().__init__(quant_layer, quant_injector)
+        self.cache_class = _CachedIOGroupwiseFloat
+
     @property
     def group_dim(self):
         return self.quant_injector.group_dim
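Note: the recurring change above is the same hook in every proxy: the concrete class selects its cache container in __init__, and the shared base class only ever goes through self.cache_class. A minimal sketch of that pattern, with all names invented for illustration (this is not Brevitas API):

    class CachedFloat:
        """Stand-in for _CachedIOFloat: holds the cached output, or metadata only."""

        def __init__(self, quant_tensor, metadata_only: bool):
            self.quant_tensor = None if metadata_only else quant_tensor


    class BaseProxy:
        """Base class: caches through whatever container the subclass picked."""

        def __init__(self):
            self.cache_class = None  # concrete subclasses assign a container class
            self._cached = None

        def maybe_cache(self, out, metadata_only: bool = False):
            # Mirrors self._cached_weight = self.cache_class(...) in the diffs
            if self.cache_class is not None:
                self._cached = self.cache_class(out, metadata_only)


    class FloatProxy(BaseProxy):

        def __init__(self):
            super().__init__()
            self.cache_class = CachedFloat  # as WeightFloatQuantProxyFromInjector does

The groupwise proxies below repeat the same three __init__ lines with their own containers (_CachedIOGroupwiseFloat, _CachedIOGroupwiseInt).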
diff --git a/src/brevitas/proxy/groupwise_float_runtime_quant.py b/src/brevitas/proxy/groupwise_float_runtime_quant.py
index 7d7828ea2..c98ff0eaf 100644
--- a/src/brevitas/proxy/groupwise_float_runtime_quant.py
+++ b/src/brevitas/proxy/groupwise_float_runtime_quant.py
@@ -1,4 +1,6 @@
-from typing import Any, Optional, Tuple
+from typing import Any, Optional, Tuple, Union
+
+import torch
 
 from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase
 from brevitas.quant_tensor import GroupwiseFloatQuantTensor
@@ -21,7 +23,7 @@ def group_size(self):
 
     def create_quant_tensor(
             self,
-            qt_args: Tuple[Any],
+            qt_args: Union[torch.Tensor, Tuple[Any]],
             x: Optional[GroupwiseFloatQuantTensor] = None) -> GroupwiseFloatQuantTensor:
         if x is None:
             value, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args
diff --git a/src/brevitas/proxy/groupwise_int_parameter_quant.py b/src/brevitas/proxy/groupwise_int_parameter_quant.py
index d0cd10334..b4049cb55 100644
--- a/src/brevitas/proxy/groupwise_int_parameter_quant.py
+++ b/src/brevitas/proxy/groupwise_int_parameter_quant.py
@@ -1,11 +1,19 @@
-from typing import Any, List
+from typing import Any, Tuple
 
+import torch.nn as nn
+
+from brevitas.inject import BaseInjector as Injector
 from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
 from brevitas.quant_tensor import GroupwiseIntQuantTensor
+from brevitas.utils.quant_utils import _CachedIOGroupwiseInt
 
 
 class GroupwiseWeightQuantProxyFromInjector(WeightQuantProxyFromInjector):
 
+    def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
+        super().__init__(quant_layer, quant_injector)
+        self.cache_class = _CachedIOGroupwiseInt
+
     @property
     def group_dim(self):
         return self.quant_injector.group_dim
@@ -14,7 +22,7 @@ def group_dim(self):
     @property
     def group_size(self):
         return self.quant_injector.group_size
 
-    def create_quant_tensor(self, qt_args: List[Any]) -> GroupwiseIntQuantTensor:
+    def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseIntQuantTensor:
         out, scale, zero_point, bit_width = qt_args
         return GroupwiseIntQuantTensor(
             out,
diff --git a/src/brevitas/proxy/groupwise_int_runtime_quant.py b/src/brevitas/proxy/groupwise_int_runtime_quant.py
index 34ab22619..42e595fd0 100644
--- a/src/brevitas/proxy/groupwise_int_runtime_quant.py
+++ b/src/brevitas/proxy/groupwise_int_runtime_quant.py
@@ -1,4 +1,6 @@
-from typing import Any, Optional, Tuple
+from typing import Any, Optional, Tuple, Union
+
+import torch
 
 from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
 from brevitas.quant_tensor import GroupwiseIntQuantTensor
@@ -21,7 +23,7 @@ def group_size(self):
 
     def create_quant_tensor(
             self,
-            qt_args: Tuple[Any],
+            qt_args: Union[torch.Tensor, Tuple[Any]],
             x: Optional[GroupwiseIntQuantTensor] = None) -> GroupwiseIntQuantTensor:
         if x is None:
             value, scale, zero_point, bit_width, = qt_args
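Note: the widening of qt_args to Union[torch.Tensor, Tuple[Any]] in every runtime create_quant_tensor above exists because the method is now called with two input shapes. A hedged sketch of that dispatch, using the int groupwise tuple layout shown in the diff (the function name and return values are invented for illustration):

    from typing import Any, Optional, Tuple, Union

    import torch


    def create_quant_tensor_sketch(
            qt_args: Union[torch.Tensor, Tuple[Any, ...]],
            x: Optional[Any] = None) -> Tuple[Any, ...]:
        if x is None:
            # Quantization ran: qt_args is the full tuple produced by tensor_quant,
            # e.g. (value, scale, zero_point, bit_width) in the int case.
            value, scale, zero_point, bit_width = qt_args
            return (value, scale, zero_point, bit_width)
        # Passthrough: qt_args is a bare Tensor and the quant metadata
        # (scale, zero-point, bit width) is inherited from the quantized input x.
        return (qt_args, x.scale, x.zero_point, x.bit_width)

The parameter_quant.py and runtime_quant.py hunks below contain the forward-pass changes that actually produce these two shapes.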
diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py
index 1f9cc62b6..e7818359e 100644
--- a/src/brevitas/proxy/parameter_quant.py
+++ b/src/brevitas/proxy/parameter_quant.py
@@ -116,11 +116,17 @@ def create_quant_tensor(self, qt_args: Tuple[Any]) -> Union[Tensor, QuantTensor]
 
     def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
         if self.is_quant_enabled:
-            if self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only:
+            # If quant is enabled the priority is:
+            # - export mode
+            # - cached weight
+            # - quantization flow
+            if self.export_mode:
+                out = self.export_handler(x)
+                out = self.create_quant_tensor(out)
+            elif self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only:
                 out = self._cached_weight.quant_tensor
             else:
-                impl = self.export_handler if self.export_mode else self.tensor_quant
-                out = impl(x)
+                out = self.tensor_quant(x)
                 out = self.create_quant_tensor(out)
             if not self.training and self.cache_inference_quant_weight and self._cached_weight is None:
                 self._cached_weight = self.cache_class(
diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py
index 14dfdb564..b06778d5b 100644
--- a/src/brevitas/proxy/runtime_quant.py
+++ b/src/brevitas/proxy/runtime_quant.py
@@ -5,6 +5,7 @@
 from abc import abstractmethod
 from typing import Any, Optional, Tuple, Union
 
+import torch
 from torch import nn
 from torch import Tensor
 from torch.nn import Identity
@@ -142,7 +143,14 @@ def init_tensor_quant(self):
             self.fused_activation_quant_proxy = None
 
     @abstractmethod
-    def create_quant_tensor(self, qt_args, x=None):
+    def create_quant_tensor(
+            self,
+            qt_args: Union[torch.Tensor, Tuple[Any]],
+            x: Optional[QuantTensor] = None) -> QuantTensor:
+        # Supports the following:
+        # - qt_args as tuple of Tensors and bools = standard quant activations
+        # - qt_args as Tensor and x as QuantTensor = passthrough activation
+        # In both cases, the output is a QuantTensor
         raise NotImplementedError
 
     def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
@@ -160,21 +168,24 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
             elif not self.is_quant_enabled:
                 # A tuple helps later with control flows
                 # The second None value is used later
-                y = (self.fused_activation_quant_proxy.activation_impl(y), None)
+                y = self.fused_activation_quant_proxy.activation_impl(y)
             else:
                 y = self.fused_activation_quant_proxy(y)
-            # If y is an empty IntQuantTensor, we need to check if this is a passthrough proxy,
-            # otherwise return a simple Tensor
+            # If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
+            # otherwise return a simple Tensor
             # If the second value (i.e., scale) is None, then quant is disabled
             if isinstance(y, tuple) and y[1] is not None:
                 out = self.create_quant_tensor(y)
             elif self.is_passthrough_act and isinstance(x, QuantTensor):
-                # preserve scale/zp/bit/sign even without output quant
-                y = y[0]
+                # preserve quant_metadata
+                if isinstance(y, tuple):
+                    y = y[0]
                 out = self.create_quant_tensor(y, x=x)
             else:
-                out = y[0]
+                if isinstance(y, tuple):
+                    y = y[0]
+                out = y
         if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor):
             cached_out = self.cache_class(out.detach(), self.cache_quant_io_metadata_only)
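Note: the parameter_quant.py hunk is the behavioral core of this patch: previously a cached weight short-circuited forward before export mode was ever consulted, whereas now export mode is checked first. A standalone sketch of the resulting decision order, simplified (the metadata-only caching flag is omitted and the helper name is invented):

    def resolve_quant_weight(proxy, x):
        # 1. Export mode wins: run the export handler, then wrap the result.
        if proxy.export_mode:
            return proxy.create_quant_tensor(proxy.export_handler(x))
        # 2. Otherwise a previously cached weight is reused as-is.
        if proxy._cached_weight is not None:
            return proxy._cached_weight.quant_tensor
        # 3. Fall back to the regular quantization flow.
        return proxy.create_quant_tensor(proxy.tensor_quant(x))

The runtime_quant.py changes are the activation-side counterpart: with quantization disabled, the activation path now yields a bare Tensor instead of a (tensor, None) tuple, so every downstream branch first normalizes y with an isinstance(y, tuple) check before unpacking.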