Fix (proxy): clean-up #1011

Merged · 7 commits · Sep 3, 2024
5 changes: 0 additions & 5 deletions src/brevitas/nn/mixin/base.py
@@ -69,9 +69,6 @@ def __init__(self, return_quant_tensor: bool):
def channelwise_separable(self) -> bool:
pass

def _set_global_is_quant_layer(self, value):
config._IS_INSIDE_QUANT_LAYER = value

def get_quant_tensor_class(self, inp: Union[Tensor, QuantTensor]):
quant_tensor_classes = [
IntQuantTensor, FloatQuantTensor, GroupwiseIntQuantTensor, GroupwiseFloatQuantTensor]
@@ -81,7 +78,6 @@ def get_quant_tensor_class(self, inp: Union[Tensor, QuantTensor]):
return None

def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
self._set_global_is_quant_layer(True)
# Hack to recognize a QuantTensor that has decayed to a tuple
# when used as input to tracing (e.g. during ONNX export)
if (torch._C._get_tracing_state() is not None and isinstance(inp, tuple) and
@@ -97,7 +93,6 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
return inp

def pack_output(self, quant_output: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
self._set_global_is_quant_layer(False)
if self.return_quant_tensor:
assert isinstance(quant_output, QuantTensor), 'QuantLayer is not correctly configured, check if warnings were raised'
return quant_output
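With the global flag gone, `unpack_input` and `pack_output` are pure conversions. A minimal sketch of the round-trip these two hooks support (assumed usage based on the mixin above; `inner_forward_impl` is a hypothetical stand-in for the wrapped layer computation):

```python
from typing import Union

from torch import Tensor

from brevitas.quant_tensor import QuantTensor


# Sketch: after this change, unpack_input/pack_output no longer flip
# config._IS_INSIDE_QUANT_LAYER as a side effect of the forward pass.
def forward(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
    inp = self.unpack_input(inp)        # normalize tuple/QuantTensor input
    out = self.inner_forward_impl(inp)  # hypothetical wrapped computation
    return self.pack_output(out)        # honor return_quant_tensor
```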
62 changes: 21 additions & 41 deletions src/brevitas/proxy/float_parameter_quant.py
@@ -1,7 +1,6 @@
from typing import Optional, Union
from warnings import warn
from abc import ABC
from typing import Any, Optional, Tuple, Union

import torch
from torch import Tensor
import torch.nn as nn

@@ -12,7 +11,7 @@
from brevitas.utils.quant_utils import _CachedIOFloat


class WeightFloatQuantProxyFromInjectorBase(WeightQuantProxyFromInjectorBase):
class WeightFloatQuantProxyFromInjectorBase(WeightQuantProxyFromInjectorBase, ABC):

def scale(self):
if not self.is_quant_enabled:
@@ -84,46 +83,27 @@ def is_fnuz(self):
) is None and self.exponent_bias() == 16
return is_fnuz_e4m3 or is_fnuz_e5m2

def forward(self, x: torch.Tensor) -> Union[Tensor, FloatQuantTensor]:
if self.is_quant_enabled:
impl = self.export_handler if self.export_mode else self.tensor_quant
out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x)
return FloatQuantTensor(
out,
scale,
zero_point,
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
saturating,
inf_values,
nan_values,
self.is_signed,
self.training)
else: # quantization disabled
return x


class WeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase):

def forward(self, x: torch.Tensor) -> Union[Tensor, FloatQuantTensor]:
if self.is_quant_enabled:
impl = self.export_handler if self.export_mode else self.tensor_quant
out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x)
return FloatQuantTensor(
out,
scale,
zero_point,
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
saturating,
inf_values,
nan_values,
self.is_signed,
self.training)
else: # quantization disabled
return x
def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
super().__init__(quant_layer, quant_injector)
self.cache_class = _CachedIOFloat

def create_quant_tensor(self, qt_args: Tuple[Any]) -> FloatQuantTensor:
out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args
return FloatQuantTensor(
out,
scale,
zero_point,
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
saturating,
inf_values,
nan_values,
self.is_signed,
self.training)


class BiasFloatQuantProxyFromInjector(BiasQuantProxyFromInjectorBase):
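Both weight float proxies lose their duplicated `forward`; the shared logic presumably moves into `WeightQuantProxyFromInjectorBase` and calls the new `create_quant_tensor` hook. A rough sketch of that dispatch, assuming the base-class `forward` (not shown in this diff) keeps the shape of the removed code:

```python
from typing import Union

import torch
from torch import Tensor

from brevitas.quant_tensor import FloatQuantTensor


# Assumed shape of the shared base-class forward after this refactor;
# the real implementation lives outside this diff.
def forward(self, x: torch.Tensor) -> Union[Tensor, FloatQuantTensor]:
    if not self.is_quant_enabled:  # quantization disabled: pass through
        return x
    impl = self.export_handler if self.export_mode else self.tensor_quant
    return self.create_quant_tensor(impl(x))  # subclass picks the tensor type
```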
127 changes: 25 additions & 102 deletions src/brevitas/proxy/float_runtime_quant.py
@@ -1,18 +1,16 @@
from typing import Optional, Union
from warnings import warn
from abc import ABC
from typing import Any, Optional, Tuple, Union

import torch
from torch import Tensor
import torch.nn as nn

from brevitas.inject import BaseInjector as Injector
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjectorBase
from brevitas.quant_tensor import FloatQuantTensor
from brevitas.quant_tensor.base_quant_tensor import QuantTensor
from brevitas.utils.quant_utils import _CachedIOFloat


class ActFloatQuantProxyFromInjectorBase(ActQuantProxyFromInjectorBase):
class ActFloatQuantProxyFromInjectorBase(ActQuantProxyFromInjectorBase, ABC):

def scale(self, force_eval=True):
return self.retrieve_attribute('scale', force_eval)
@@ -60,105 +58,30 @@ def is_fnuz(self):
) is None and self.exponent_bias() == 16
return is_fnuz_e4m3 or is_fnuz_e5m2

def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, FloatQuantTensor]:
out = x
if self.fused_activation_quant_proxy is not None:
y = x
if isinstance(y, QuantTensor):
y = y.value

if self.export_mode:
y = self.fused_activation_quant_proxy.activation_impl(y)
y = self.export_handler(y)
elif not self.is_quant_enabled:
y = self.fused_activation_quant_proxy.activation_impl(y)
else:
y = self.fused_activation_quant_proxy(y)
# If y is an empty FloatQuantTensor, we need to check if this is a passthrough proxy,
# otherwise return a simple Tensor
# We exclude the last two values (inf_values and nan_values)
if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])):
out = FloatQuantTensor(*y, signed=self.is_signed, training=self.training)
elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant
if isinstance(y, tuple):
y = y[0]
if isinstance(x, FloatQuantTensor):
out = FloatQuantTensor(
y,
x.scale,
x.zero_point,
x.exponent_bit_width,
x.mantissa_bit_width,
x.exponent_bias,
x.saturating,
x.inf_values,
x.nan_values,
x.signed,
self.training)
else:
out = y
else:
if isinstance(y, tuple):
y = y[0]
out = y
else:
# If fused activation quant proxy is not enabled, return the input
out = x
if not self.training and self.cache_inference_quant_act and isinstance(out,
FloatQuantTensor):
cached_out = _CachedIOFloat(out.detach(), self.cache_quant_io_metadata_only)
self._cached_act = cached_out
return out


class ActFloatQuantProxyFromInjector(ActFloatQuantProxyFromInjectorBase):

def forward(self, x: Union[Tensor, FloatQuantTensor]) -> Union[Tensor, FloatQuantTensor]:
out = x
if self.fused_activation_quant_proxy is not None:
y = x
if isinstance(y, FloatQuantTensor):
y = y.value

if self.export_mode:
y = self.fused_activation_quant_proxy.activation_impl(y)
y = self.export_handler(y)
elif not self.is_quant_enabled:
y = self.fused_activation_quant_proxy.activation_impl(y)
else:
y = self.fused_activation_quant_proxy(y)
# If y is an empty FloatQuantTensor, we need to check if this is a passthrough proxy,
# otherwise return a simple Tensor
# We exclude the last two values (inf_values and nan_values)
if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])):
out = FloatQuantTensor(*y, signed=self.is_signed, training=self.training)
elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant
if isinstance(y, tuple):
y = y[0]
if isinstance(x, FloatQuantTensor):
out = FloatQuantTensor(
y,
x.scale,
x.zero_point,
x.mantissa_bit_width,
x.exponent_bit_width,
x.exponent_bias,
x.saturating,
x.inf_values,
x.nan_values,
x.signed,
self.training)
else:
out = y
else:
if isinstance(y, tuple):
y = y[0]
out = y
def __init__(self, quant_layer: nn.Module, quant_injector: Injector):
super().__init__(quant_layer, quant_injector)
self.cache_class = _CachedIOFloat

def create_quant_tensor(
self,
qt_args: Union[torch.Tensor, Tuple[Any]],
x: Optional[FloatQuantTensor] = None) -> FloatQuantTensor:
if x is None:
out = FloatQuantTensor(*qt_args, signed=self.is_signed, training=self.training)
else:
# If fused activation quant proxy is not enabled, return the input
out = x
if not self.training and self.cache_inference_quant_act and isinstance(out,
FloatQuantTensor):
cached_out = _CachedIOFloat(out.detach(), self.cache_quant_io_metadata_only)
self._cached_act = cached_out
out = FloatQuantTensor(
qt_args,
x.scale,
x.zero_point,
x.mantissa_bit_width,
x.exponent_bit_width,
x.exponent_bias,
x.saturating,
x.inf_values,
x.nan_values,
x.signed,
self.training)
return out
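On the activation side, `create_quant_tensor` covers two cases: a fresh quant output (a full tuple of quant args) and a passthrough activation that re-wraps the raw value with the input tensor's metadata. Note that the passthrough branch passes `x.mantissa_bit_width` before `x.exponent_bit_width`, carried over verbatim from the removed `forward`, even though every other `FloatQuantTensor` construction in this PR passes `exponent_bit_width` first. A sketch of how the shared `ActQuantProxyFromInjectorBase.forward` presumably chooses between the two modes (assumed; that method sits outside this diff, and `_quantize_activation` is a hypothetical framing helper):

```python
from brevitas.quant_tensor import FloatQuantTensor


# Assumed core of the shared activation forward after this refactor:
# x is the original input, value its underlying tensor.
def _quantize_activation(self, x, value):
    y = self.fused_activation_quant_proxy(value)
    # Full quant-arg tuple (inf_values/nan_values may legitimately be None)
    if isinstance(y, tuple) and all(f is not None for f in y[:-2]):
        return self.create_quant_tensor(y)
    if isinstance(y, tuple):
        y = y[0]
    if self.is_passthrough_act and isinstance(x, FloatQuantTensor):
        return self.create_quant_tensor(y, x=x)  # reuse input metadata
    return y  # plain tensor fallback
```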
47 changes: 24 additions & 23 deletions src/brevitas/proxy/groupwise_float_parameter_quant.py
@@ -1,14 +1,19 @@
from typing import Union
from typing import Any, Tuple

import torch
from torch import Tensor
import torch.nn as nn

from brevitas.inject import BaseInjector as Injector
from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjectorBase
from brevitas.quant_tensor import GroupwiseFloatQuantTensor
from brevitas.utils.quant_utils import _CachedIOGroupwiseFloat


class GroupwiseWeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase):

def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
super().__init__(quant_layer, quant_injector)
self.cache_class = _CachedIOGroupwiseFloat

@property
def group_dim(self):
return self.quant_injector.group_dim
@@ -17,23 +22,19 @@ def group_dim(self):
def group_size(self):
return self.quant_injector.group_size

def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseFloatQuantTensor]:
if self.is_quant_enabled:
impl = self.export_handler if self.export_mode else self.tensor_quant
out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x)
return GroupwiseFloatQuantTensor(
out,
scale,
zero_point,
self.group_size,
self.group_dim,
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
saturating,
inf_values,
nan_values,
self.is_signed,
self.training)
else: # quantization disabled
return x
def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseFloatQuantTensor:
out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args
return GroupwiseFloatQuantTensor(
out,
scale,
zero_point,
self.group_size,
self.group_dim,
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
saturating,
inf_values,
nan_values,
self.is_signed,
self.training)
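The groupwise proxy follows the same pattern, with `group_size` and `group_dim` pulled from the injector and `_CachedIOGroupwiseFloat` registered as the cache class. A hypothetical sanity check (all names assumed, requiring a configured proxy and weight tensor):

```python
# Hypothetical usage: tensor_quant returns the quant-arg tuple that
# create_quant_tensor unpacks into a GroupwiseFloatQuantTensor.
qt = proxy.create_quant_tensor(proxy.tensor_quant(weight))
assert isinstance(qt, GroupwiseFloatQuantTensor)
assert qt.group_size == proxy.group_size and qt.group_dim == proxy.group_dim
```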