Skip to content

Commit

Permalink
Review
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 17, 2023
1 parent 84d25a6 commit 4f2a20b
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 28 deletions.
8 changes: 4 additions & 4 deletions src/brevitas/export/common/handler/qcdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,17 @@ def assert_ge_zero(self, *args):
assert bools


class CDQMixin(DQMixin, ABC):
class DQCastMixin(DQMixin, ABC):

@abstractmethod
def clip_fn(self, x, min_val, max_val):
def cast_fn(self, x, dtype):
pass


class CastMixin(ABC):
class CDQMixin(DQMixin, ABC):

@abstractmethod
def cast_fn(self, x, dtype):
def clip_fn(self, x, min_val, max_val):
pass


Expand Down
14 changes: 5 additions & 9 deletions src/brevitas/export/onnx/standard/qcdq/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@

import torch

from brevitas.export.common.handler.qcdq import CastMixin
from brevitas.export.common.handler.qcdq import CDQMixin
from brevitas.export.common.handler.qcdq import DQMixin
from brevitas.export.common.handler.qcdq import DQCastMixin
from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin
Expand All @@ -24,11 +23,14 @@
from ..function import QuantizeLinearFn


class StdDQONNXMixin(DQMixin, ABC):
class StdDQCastONNXMixin(DQCastMixin, ABC):

def dequantize_fn(self, x, scale, zero_point, axis):
return DequantizeLinearFn.apply(x, scale, zero_point, axis)

def cast_fn(self, x, dtype):
return CastFn.apply(x, dtype)

@property
def flatten_dequantize_params(self):
return True
Expand All @@ -41,12 +43,6 @@ def validate(self, module):
assert module.bit_width() > 1., 'Binary quant not supported'


class StdDQCastONNXMixin(CastMixin, StdDQONNXMixin):

def cast_fn(self, x, dtype):
return CastFn.apply(x, dtype)


class StdCDQCastONNXMixin(CDQMixin, StdDQCastONNXMixin, ABC):

def clip_fn(self, x, min_val, max_val):
Expand Down
28 changes: 13 additions & 15 deletions src/brevitas/export/torch/qcdq/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch

from brevitas.export.common.handler.base import BaseHandler
from brevitas.export.common.handler.qcdq import DQMixin
from brevitas.export.common.handler.qcdq import DQCastMixin
from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin
Expand All @@ -23,7 +23,7 @@ def _itemize_clip_bounds(clip_args):
return clip_args


class TorchDQMixin(DQMixin, ABC):
class TorchDQCastMixin(DQCastMixin, ABC):

def __init__(self) -> None:
super().__init__()
Expand All @@ -40,6 +40,9 @@ def dequantize_fn(self, x, scale, zero_point, axis):
zero_point = float(zero_point)
return (x - zero_point) * scale

def cast_fn(self, x, dtype):
return x.type(dtype)

@property
def flatten_dequantize_params(self):
return False
Expand All @@ -52,19 +55,13 @@ def validate(self, module):
assert module.bit_width() > 1., 'Binary quant not supported'


class TorchDQCastMixin(TorchDQMixin, ABC):

def cast_fn(self, x, dtype):
return x.type(dtype)


class TorchCDQMixin(TorchDQCastMixin, ABC):
class TorchCDQCastMixin(TorchDQCastMixin, ABC):

def clip_fn(self, x, min_val, max_val):
return torch.clamp(x, min_val, max_val)


class TorchQCDQMixin(QMixin, TorchCDQMixin, ABC):
class TorchQCDQCastMixin(QMixin, TorchCDQCastMixin, ABC):

@classmethod
def int8_dtype(cls):
Expand Down Expand Up @@ -96,7 +93,7 @@ def forward(self, *args, **kwargs):
return self.symbolic_execution(*args, **kwargs)


class TorchQCDQWeightQuantProxyHandler(TorchCDQMixin,
class TorchQCDQWeightQuantProxyHandler(TorchCDQCastMixin,
QCDQWeightQuantProxyHandlerMixin,
TorchQCDQHandler):

Expand All @@ -106,7 +103,7 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
return _itemize_clip_bounds(clip_args)


class TorchQCDQDecoupledWeightQuantProxyHandler(TorchCDQMixin,
class TorchQCDQDecoupledWeightQuantProxyHandler(TorchCDQCastMixin,
QCDQDecoupledWeightQuantProxyHandlerMixin,
TorchQCDQHandler):

Expand All @@ -117,15 +114,16 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):


class TorchQCDQDecoupledWeightQuantWithInputProxyHandler(
TorchCDQMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, TorchQCDQHandler):
TorchCDQCastMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, TorchQCDQHandler):

@classmethod
def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
clip_args = super().int_clip_symbolic_kwargs(narrow, signed, bit_width)
return _itemize_clip_bounds(clip_args)


class TorchQCDQActQuantProxyHandler(TorchQCDQMixin, QCDQActQuantProxyHandlerMixin,
class TorchQCDQActQuantProxyHandler(TorchQCDQCastMixin,
QCDQActQuantProxyHandlerMixin,
TorchQCDQHandler):

@classmethod
Expand All @@ -140,7 +138,7 @@ class TorchQCDQBiasQuantProxyHandler(TorchDQCastMixin,
pass


class TorchQCDQTruncQuantProxyHandler(TorchQCDQMixin,
class TorchQCDQTruncQuantProxyHandler(TorchQCDQCastMixin,
QCDQTruncQuantProxyHandlerMixin,
TorchQCDQHandler):

Expand Down

0 comments on commit 4f2a20b

Please sign in to comment.