Skip to content

Commit

Permalink
Feat (nn): add QuantConv3d and QuantConvTranspose3d (#805)
Browse files Browse the repository at this point in the history
  • Loading branch information
costigt-dev authored Mar 7, 2024
1 parent 0a023aa commit d981aa9
Show file tree
Hide file tree
Showing 26 changed files with 663 additions and 34 deletions.
32 changes: 32 additions & 0 deletions src/brevitas/export/onnx/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,38 @@ def kernel_shape(module):
return list(module.kernel_size)


class Kernel3dApplHandlerMixin(ABC):
    """Mixin that normalizes 3D conv hyper-parameters for ONNX export.

    Each helper accepts a module carrying 3D conv attributes (``padding``,
    ``stride``, ``dilation``, ``kernel_size``) and returns a plain list,
    expanding scalar (int) values to the per-dimension form expected by
    ONNX attributes.
    """

    @staticmethod
    def padding(module):
        # ONNX pads are [d_begin, h_begin, w_begin, d_end, h_end, w_end]:
        # a scalar expands to 6 entries, a 3-tuple is repeated as begin+end.
        pad = module.padding
        if isinstance(pad, int):
            return [pad] * 6
        pad = list(pad)
        return pad + pad

    @staticmethod
    def stride(module):
        # Scalar stride applies to all three spatial dimensions.
        s = module.stride
        return [s] * 3 if isinstance(s, int) else list(s)

    @staticmethod
    def dilation(module):
        # Scalar dilation applies to all three spatial dimensions.
        d = module.dilation
        return [d] * 3 if isinstance(d, int) else list(d)

    @staticmethod
    def kernel_shape(module):
        # Scalar kernel size means a cubic kernel.
        k = module.kernel_size
        return [k] * 3 if isinstance(k, int) else list(k)


class ONNXBaseHandler(BaseHandler, ABC):

def __init__(self):
Expand Down
10 changes: 8 additions & 2 deletions src/brevitas/export/onnx/standard/qoperator/handler/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from brevitas.export.common import to_0dim_if_scalar
from brevitas.export.onnx.handler import Kernel1dApplHandlerMixin
from brevitas.export.onnx.handler import Kernel2dApplHandlerMixin
from brevitas.export.onnx.handler import Kernel3dApplHandlerMixin
from brevitas.export.onnx.standard.function import DequantizeLinearFn
from brevitas.export.onnx.standard.function import IntClipFn
from brevitas.export.onnx.standard.function import QuantizeLinearFn
from brevitas.nn import QuantConv1d
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantConv3d
from brevitas.nn import QuantLinear
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL

Expand Down Expand Up @@ -56,7 +58,7 @@ def validate(cls, module: QuantWBIOL, requires_quant_bias=True):
assert module.is_quant_bias_signed
cls.validate_32b_bit_width(module.quant_bias_bit_width(), le_then=True)

def prepare_for_export(self, module: Union[QuantConv1d, QuantConv2d]):
def prepare_for_export(self, module: Union[QuantConv1d, QuantConv2d, QuantConv3d]):
self.validate(module)

op_symbolic_kwargs = self.op_symbolic_kwargs(module)
Expand Down Expand Up @@ -106,7 +108,7 @@ def output_symbolic_execution(self, out: Tensor):

class StdQOpONNXQuantConvNdHandler(StdQOpONNXQuantWBIOLHandler, ABC):

def op_symbolic_kwargs(self, module: Union[QuantConv1d, QuantConv2d]):
def op_symbolic_kwargs(self, module: Union[QuantConv1d, QuantConv2d, QuantConv3d]):
conv_symbolic_kwargs = {
'input_scale': module.quant_input_scale(),
'input_zero_point': self.quant_input_zero_point(module),
Expand All @@ -131,6 +133,10 @@ def op_symbolic_execution(self, inp: Tensor):
return out


class StdQOpONNXQuantConv3dHandler(StdQOpONNXQuantConvNdHandler, Kernel3dApplHandlerMixin):
    """QOperator ONNX export handler for ``QuantConv3d`` layers.

    All symbolic-kwargs/execution logic comes from the shared Nd base
    handler; the 3D mixin supplies padding/stride/dilation/kernel-shape
    normalization.
    """
    handled_layer = QuantConv3d


class StdQOpONNXQuantConv2dHandler(StdQOpONNXQuantConvNdHandler, Kernel2dApplHandlerMixin):
handled_layer = QuantConv2d

Expand Down
2 changes: 2 additions & 0 deletions src/brevitas/export/onnx/standard/qoperator/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .handler.base import StdQOpONNXQuantLayerHandler
from .handler.parameter import StdQOpONNXQuantConv1dHandler
from .handler.parameter import StdQOpONNXQuantConv2dHandler
from .handler.parameter import StdQOpONNXQuantConv3dHandler
from .handler.parameter import StdQOpONNXQuantLinearHandler


Expand All @@ -39,6 +40,7 @@ class StdQOpONNXManager(StdONNXBaseManager):
handlers = [
StdQOpONNXQuantConv1dHandler,
StdQOpONNXQuantConv2dHandler,
StdQOpONNXQuantConv3dHandler,
StdQOpONNXQuantLinearHandler,
StdQOpONNXQuantReLUHandler,
StdQOpONNXQuantHardTanhHandler,
Expand Down
11 changes: 10 additions & 1 deletion src/brevitas/export/torch/qoperator/handler/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from brevitas.nn import QuantConv1d
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantConv3d
from brevitas.nn import QuantLinear
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL

Expand Down Expand Up @@ -93,7 +94,7 @@ def explicit_output_dtype(cls):
return True

@classmethod
def prepare_qf_kwargs(cls, module: Union[QuantConv1d, QuantConv2d]):
def prepare_qf_kwargs(cls, module: Union[QuantConv1d, QuantConv2d, QuantConv3d]):
return {
'bias': cls.prepare_bias(module),
'stride': module.stride,
Expand All @@ -119,6 +120,14 @@ def prepare_qf(cls, module: QuantConv2d):
return torch.nn.quantized.functional.conv2d, cls.prepare_qf_kwargs(module)


class PytorchQuantConv3dHandler(PytorchQuantConvNdHandler):
    """Torch QOperator export handler mapping ``QuantConv3d`` onto
    ``torch.nn.quantized.functional.conv3d``."""

    handled_layer = QuantConv3d

    @classmethod
    def prepare_qf(cls, module: QuantConv3d):
        # Pair the quantized functional op with the kwargs derived from
        # the quantized module (bias, stride, padding, ...).
        quantized_fn = torch.nn.quantized.functional.conv3d
        return quantized_fn, cls.prepare_qf_kwargs(module)


class PytorchQuantLinearHandler(PytorchQuantWBIOLHandler):
handled_layer = QuantLinear

Expand Down
2 changes: 2 additions & 0 deletions src/brevitas/export/torch/qoperator/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .handler.act import PytorchQuantReLUHandler
from .handler.parameter import PytorchQuantConv1dHandler
from .handler.parameter import PytorchQuantConv2dHandler
from .handler.parameter import PytorchQuantConv3dHandler
from .handler.parameter import PytorchQuantLinearHandler


Expand All @@ -31,6 +32,7 @@ class TorchQOpManager(BaseManager):
PytorchQuantReLUHandler,
PytorchQuantConv1dHandler,
PytorchQuantConv2dHandler,
PytorchQuantConv3dHandler,
PytorchQuantLinearHandler]

@classmethod
Expand Down
11 changes: 8 additions & 3 deletions src/brevitas/graph/fixed_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,14 @@ class MoveSplitBatchNormBeforeCat(UntilFixedPointGraphTransform):
nn.Conv3d,
nn.ConvTranspose1d,
nn.ConvTranspose2d,
nn.ConvTranspose3d,
qnn.QuantLinear,
qnn.QuantConv1d,
qnn.QuantConv2d,
qnn.QuantConv3d,
qnn.QuantConvTranspose1d,
qnn.QuantConvTranspose2d)
qnn.QuantConvTranspose2d,
qnn.QuantConvTranspose3d)

def __init__(self, before_modules_types=DEFAULT_BEFORE_MODULES_TYPES):
super(MoveSplitBatchNormBeforeCat, self).__init__()
Expand Down Expand Up @@ -93,8 +96,10 @@ class MergeBatchNorm(UntilFixedPointGraphTransform):
nn.BatchNorm1d), (qnn.BatchNorm2dToQuantScaleBias,
nn.BatchNorm2d), (qnn.QuantLinear, nn.BatchNorm1d),
(qnn.QuantConv1d, nn.BatchNorm1d), (qnn.QuantConv2d, nn.BatchNorm2d),
(qnn.QuantConvTranspose1d,
nn.BatchNorm1d), (qnn.QuantConvTranspose2d, nn.BatchNorm2d))
(qnn.QuantConv3d,
nn.BatchNorm3d), (qnn.QuantConvTranspose1d, nn.BatchNorm1d),
(qnn.QuantConvTranspose2d,
nn.BatchNorm2d), (qnn.QuantConvTranspose3d, nn.BatchNorm3d))

def __init__(self, patterns=DEFAULT_PATTERNS):
super(MergeBatchNorm, self).__init__()
Expand Down
8 changes: 6 additions & 2 deletions src/brevitas/graph/gpfq.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ def update_batch(self, module, input, current_layer):

if isinstance(self.layer, SUPPORTED_CONV_OP):
# Pick the correct unfoldNd class
if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
if isinstance(
self.layer,
(qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)):
unfold_impl = unfoldNd.UnfoldTransposeNd
else:
unfold_impl = unfoldNd.UnfoldNd
Expand Down Expand Up @@ -239,7 +241,9 @@ def single_layer_update(self):
dev = weight.device
dtype = weight.dtype
if isinstance(self.layer, SUPPORTED_CONV_OP):
if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
if isinstance(
self.layer,
(qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)):
weight = weight.transpose(1, 0) # This performs a view
weight = weight.flatten(1)
weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC]
Expand Down
8 changes: 6 additions & 2 deletions src/brevitas/graph/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ def update_batch(self, module, input, current_layer):

if isinstance(self.layer, SUPPORTED_CONV_OP):
# Pick the correct unfoldNd class
if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
if isinstance(
self.layer,
(qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)):
unfold_impl = unfoldNd.UnfoldTransposeNd
else:
unfold_impl = unfoldNd.UnfoldNd
Expand Down Expand Up @@ -204,7 +206,9 @@ def single_layer_update(self, percdamp=.01):
dtype = weight.dtype

if isinstance(self.layer, SUPPORTED_CONV_OP):
if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
if isinstance(
self.layer,
(qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)):
weight = weight.transpose(1, 0) # This performs a view
weight = weight.flatten(1)

Expand Down
11 changes: 9 additions & 2 deletions src/brevitas/graph/gpxq.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
import brevitas.nn as qnn

SUPPORTED_CONV_OP = (
qnn.QuantConv2d, qnn.QuantConv1d, qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)
qnn.QuantConv1d,
qnn.QuantConv2d,
qnn.QuantConv3d,
qnn.QuantConvTranspose1d,
qnn.QuantConvTranspose2d,
qnn.QuantConvTranspose3d)


class StopFwdException(Exception):
Expand Down Expand Up @@ -196,7 +201,9 @@ def __init__(
# By default, use groups = 1
self.groups = 1
if isinstance(self.layer, SUPPORTED_CONV_OP):
if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
if isinstance(
self.layer,
(qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)):
weight = weight.transpose(1, 0) # This performs a view
weight = weight.flatten(1)
self.groups = self.layer.groups
Expand Down
6 changes: 4 additions & 2 deletions src/brevitas/graph/per_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from brevitas.graph.utils import replace_module
from brevitas.nn import QuantConv1d
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantConv3d

from .base import PerInputModuleToModuleByHook

Expand Down Expand Up @@ -93,9 +94,10 @@ def replace_modules(self, model):
dw_conv = QuantConv1d(**kwargs)
elif isinstance(avgpool, nn.AvgPool2d):
dw_conv = QuantConv2d(**kwargs)
elif isinstance(avgpool, nn.AvgPool3d):
dw_conv = QuantConv3d(**kwargs)
else:
assert isinstance(avgpool, nn.AvgPool3d)
raise RuntimeError("QuantConv3d not supported yet.")
raise RuntimeError("Unsupported operation.")
kernel_value = 1. / reduce(mul, dw_conv.kernel_size)
dw_conv.register_parameter(
'scalar_weight', torch.nn.Parameter(torch.tensor(kernel_value)))
Expand Down
26 changes: 26 additions & 0 deletions src/brevitas/graph/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@
'weight_quant': Int8WeightPerTensorFloat,
'bias_quant': Int32Bias,
'return_quant_tensor': True}),
nn.Conv3d: (
qnn.QuantConv3d,
{
'weight_quant': Int8WeightPerTensorFloat,
'bias_quant': Int32Bias,
'return_quant_tensor': True}),
nn.ConvTranspose1d: (
qnn.QuantConvTranspose1d,
{
Expand All @@ -94,6 +100,12 @@
'weight_quant': Int8WeightPerTensorFloat,
'bias_quant': Int32Bias,
'return_quant_tensor': True}),
nn.ConvTranspose3d: (
qnn.QuantConvTranspose3d,
{
'weight_quant': Int8WeightPerTensorFloat,
'bias_quant': Int32Bias,
'return_quant_tensor': True}),
nn.Linear: (
qnn.QuantLinear,
{
Expand Down Expand Up @@ -151,6 +163,13 @@
'weight_quant': Int8WeightPerTensorFloat,
'bias_quant': Int32Bias,
'return_quant_tensor': False}),
nn.Conv3d: (
qnn.QuantConv3d,
{
'input_quant': Int8ActPerTensorFloat,
'weight_quant': Int8WeightPerTensorFloat,
'bias_quant': Int32Bias,
'return_quant_tensor': False}),
nn.ConvTranspose1d: (
qnn.QuantConvTranspose1d,
{
Expand All @@ -165,6 +184,13 @@
'weight_quant': Int8WeightPerTensorFloat,
'bias_quant': Int32Bias,
'return_quant_tensor': False}),
nn.ConvTranspose3d: (
qnn.QuantConvTranspose3d,
{
'input_quant': Int8ActPerTensorFloat,
'weight_quant': Int8WeightPerTensorFloat,
'bias_quant': Int32Bias,
'return_quant_tensor': False}),
nn.Linear: (
qnn.QuantLinear,
{
Expand Down
12 changes: 12 additions & 0 deletions src/brevitas/graph/target/flexml.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@
'weight_quant': Int8WeightPerTensorFixedPoint,
'bias_quant': Int16Bias,
'return_quant_tensor': True}),
nn.Conv3d: (
qnn.QuantConv3d,
{
'weight_quant': Int8WeightPerTensorFixedPoint,
'bias_quant': Int16Bias,
'return_quant_tensor': True}),
nn.ConvTranspose1d: (
qnn.QuantConvTranspose1d,
{
Expand All @@ -57,6 +63,12 @@
'weight_quant': Int8WeightPerTensorFixedPoint,
'bias_quant': Int16Bias,
'return_quant_tensor': True}),
nn.ConvTranspose3d: (
qnn.QuantConvTranspose3d,
{
'weight_quant': Int8WeightPerTensorFixedPoint,
'bias_quant': Int16Bias,
'return_quant_tensor': True}),
nn.BatchNorm1d: (
qnn.BatchNorm1dToQuantScaleBias,
{
Expand Down
3 changes: 2 additions & 1 deletion src/brevitas/graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
nn.ConvTranspose2d,
nn.ConvTranspose3d,
qnn.QuantConvTranspose1d,
qnn.QuantConvTranspose2d]
qnn.QuantConvTranspose2d,
qnn.QuantConvTranspose3d]


def module_class_name(m: torch.nn.Module):
Expand Down
2 changes: 2 additions & 0 deletions src/brevitas/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
from .quant_bn import BatchNorm2dToQuantScaleBias
from .quant_conv import QuantConv1d
from .quant_conv import QuantConv2d
from .quant_conv import QuantConv3d
from .quant_convtranspose import QuantConvTranspose1d
from .quant_convtranspose import QuantConvTranspose2d
from .quant_convtranspose import QuantConvTranspose3d
from .quant_eltwise import QuantCat
from .quant_eltwise import QuantEltwiseAdd
from .quant_embedding import QuantEmbedding
Expand Down
Loading

0 comments on commit d981aa9

Please sign in to comment.