Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Aug 28, 2024
1 parent 6005c97 commit 8111ae0
Show file tree
Hide file tree
Showing 9 changed files with 26 additions and 39 deletions.
9 changes: 2 additions & 7 deletions src/brevitas/proxy/float_parameter_quant.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
from typing import Any, List, Optional, Union
from typing import Any, Optional, Tuple, Union

import torch
from torch import Tensor
import torch.nn as nn

from brevitas.inject import BaseInjector as Injector
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjectorBase
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase
from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import FloatQuantTensor
from brevitas.quant_tensor.base_quant_tensor import QuantTensor
from brevitas.utils.quant_utils import _CachedIOFloat


Expand Down Expand Up @@ -88,7 +83,7 @@ def is_fnuz(self):

class WeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase):

def create_quant_tensor(self, qt_args: List[Any]) -> Union[Tensor, QuantTensor]:
def create_quant_tensor(self, qt_args: Tuple[Any]) -> FloatQuantTensor:
out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args
return FloatQuantTensor(
out,
Expand Down
12 changes: 3 additions & 9 deletions src/brevitas/proxy/float_runtime_quant.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
from typing import Optional, Union
from warnings import warn
from typing import Any, Optional, Tuple

import torch
from torch import Tensor
import torch.nn as nn

from brevitas.inject import BaseInjector as Injector
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjectorBase
from brevitas.quant_tensor import FloatQuantTensor
from brevitas.quant_tensor.base_quant_tensor import QuantTensor
from brevitas.utils.quant_utils import _CachedIOFloat


Expand Down Expand Up @@ -67,7 +60,8 @@ def __init__(self, quant_layer, quant_injector):
super().__init__(quant_layer, quant_injector)
self.cache_class = _CachedIOFloat

def create_quant_tensor(self, qt_args, x=None):
def create_quant_tensor(
self, qt_args: Tuple[Any], x: Optional[FloatQuantTensor] = None) -> FloatQuantTensor:
if x is None:
out = FloatQuantTensor(*qt_args, signed=self.is_signed, training=self.training)
else:
Expand Down
6 changes: 2 additions & 4 deletions src/brevitas/proxy/groupwise_float_parameter_quant.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from typing import Any, List, Union

from torch import Tensor
from typing import Any, Tuple

from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjectorBase
from brevitas.quant_tensor import GroupwiseFloatQuantTensor
Expand All @@ -16,7 +14,7 @@ def group_dim(self):
def group_size(self):
return self.quant_injector.group_size

def create_quant_tensor(self, qt_args: List[Any]) -> Union[Tensor, GroupwiseFloatQuantTensor]:
def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseFloatQuantTensor:
out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args
return GroupwiseFloatQuantTensor(
out,
Expand Down
7 changes: 6 additions & 1 deletion src/brevitas/proxy/groupwise_float_runtime_quant.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any, Optional, Tuple

from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase
from brevitas.quant_tensor import GroupwiseFloatQuantTensor
from brevitas.utils.quant_utils import _CachedIOGroupwiseFloat
Expand All @@ -17,7 +19,10 @@ def group_dim(self):
def group_size(self):
return self.quant_injector.group_size

def create_quant_tensor(self, qt_args, x=None):
def create_quant_tensor(
self,
qt_args: Tuple[Any],
x: Optional[GroupwiseFloatQuantTensor] = None) -> GroupwiseFloatQuantTensor:
if x is None:
value, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args
out = GroupwiseFloatQuantTensor(
Expand Down
9 changes: 2 additions & 7 deletions src/brevitas/proxy/groupwise_int_parameter_quant.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
from typing import Any, List, Optional, Union

import torch
from torch import Tensor
from typing import Any, List

from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import GroupwiseIntQuantTensor
from brevitas.utils.quant_utils import _CachedIOGroupwiseInt


class GroupwiseWeightQuantProxyFromInjector(WeightQuantProxyFromInjector):
Expand All @@ -19,7 +14,7 @@ def group_dim(self):
def group_size(self):
return self.quant_injector.group_size

def create_quant_tensor(self, qt_args: List[Any]) -> Union[Tensor, GroupwiseIntQuantTensor]:
def create_quant_tensor(self, qt_args: List[Any]) -> GroupwiseIntQuantTensor:
out, scale, zero_point, bit_width = qt_args
return GroupwiseIntQuantTensor(
out,
Expand Down
11 changes: 5 additions & 6 deletions src/brevitas/proxy/groupwise_int_runtime_quant.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from typing import Union

import torch
from torch import Tensor
from typing import Any, Optional, Tuple

from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
from brevitas.quant_tensor import GroupwiseIntQuantTensor
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.quant_utils import _CachedIOGroupwiseInt


Expand All @@ -23,7 +19,10 @@ def group_dim(self):
def group_size(self):
return self.quant_injector.group_size

def create_quant_tensor(self, qt_args, x=None):
def create_quant_tensor(
self,
qt_args: Tuple[Any],
x: Optional[GroupwiseIntQuantTensor] = None) -> GroupwiseIntQuantTensor:
if x is None:
value, scale, zero_point, bit_width, = qt_args
out = GroupwiseIntQuantTensor(
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/proxy/parameter_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def bit_width(self):
bit_width = self.__call__(self.tracked_parameter_list[0]).bit_width
return bit_width

def create_quant_tensor(self, qt_args: Tuple[Any]) -> IntQuantTensor:
    """Wrap the raw quantizer outputs in an IntQuantTensor.

    :param qt_args: tuple of positional arguments for IntQuantTensor
        (order is defined by the underlying quantizer implementation).
    :return: IntQuantTensor tagged with this proxy's signedness and
        current training mode.
    """
    trailing = (self.is_signed, self.training)
    return IntQuantTensor(*qt_args, *trailing)


Expand All @@ -208,7 +208,7 @@ def pre_zero_point(self):
out, scale, zero_point, bit_width, pre_scale, pre_zero_point = output_tuple
return pre_zero_point

def create_quant_tensor(self, qt_args: Tuple[Any]) -> IntQuantTensor:
    """Build an IntQuantTensor from the quantizer outputs.

    The quantizer emits (out, scale, zero_point, bit_width, pre_scale,
    pre_zero_point); the pre_* values are deliberately dropped here —
    they appear to be surfaced through the proxy's dedicated
    pre_scale / pre_zero_point accessors instead (confirm against the
    full class).

    :param qt_args: 6-tuple of quantizer outputs as described above.
    :return: IntQuantTensor carrying value/scale/zero_point/bit_width
        plus this proxy's signedness and training mode.
    """
    value, scale, zp, bw, _pre_scale, _pre_zp = qt_args
    return IntQuantTensor(value, scale, zp, bw, self.is_signed, self.training)

Expand Down
1 change: 0 additions & 1 deletion src/brevitas/proxy/quant_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABCMeta
from abc import abstractmethod
from typing import Optional

from torch import nn
Expand Down
6 changes: 4 additions & 2 deletions src/brevitas/proxy/runtime_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from abc import ABC
from abc import abstractmethod
from typing import Optional, Tuple, Union
from typing import Any, Optional, Tuple, Union

from torch import nn
from torch import Tensor
Expand Down Expand Up @@ -197,7 +197,9 @@ def zero_point(self, force_eval=True):
def bit_width(self, force_eval=True):
return self.retrieve_attribute('bit_width', force_eval)

def create_quant_tensor(self, qt_args, x=None):
def create_quant_tensor(
self, qt_args: Tuple[Any], x: Optional[IntQuantTensor] = None) -> IntQuantTensor:

if x is None:
out = IntQuantTensor(*qt_args, self.is_signed, self.training)
else:
Expand Down

0 comments on commit 8111ae0

Please sign in to comment.