
Feat: Add QuantConv3d and QuantConv3dTranspose #805

Merged Mar 7, 2024 · 42 commits
5979acb
first pass attempt at implementing QuantConv3D
costigt-dev Jan 19, 2024
9ecea98
placeholder implementation for QuantConvTranspose3d
costigt-dev Jan 19, 2024
091d70d
first implementation of QuantConvTranspose3d
costigt-dev Jan 19, 2024
44dc026
added new conv3d classes to the __init__.py
costigt-dev Jan 22, 2024
9c5fff5
added space to QuantConv3d to be close to other classes in file
costigt-dev Jan 22, 2024
6121a10
adapted conv2d to conv3d
costigt-dev Jan 22, 2024
c568895
removed is_same_padded_strided and its accompanying function as it is…
costigt-dev Jan 23, 2024
cf86c8b
formatting fixes
costigt-dev Jan 24, 2024
c62e07b
Revert "removed is_same_padded_strided and its accompanying function …
costigt-dev Jan 29, 2024
ff281b8
updated references to QuantConv and QuantConvTranpose thoughout code …
costigt-dev Jan 29, 2024
c749625
Merge branch 'dev' into feat/conv3d
costigt-dev Jan 29, 2024
a02aa23
Revert "updated references to QuantConv and QuantConvTranpose thougho…
costigt-dev Jan 30, 2024
10b331e
pre-commit hook changes
costigt-dev Jan 30, 2024
57d8d8b
updated references to QuantConv and QuantConvTranpose thoughout code …
costigt-dev Jan 29, 2024
d71762e
removing unused import
costigt-dev Jan 30, 2024
21d2d08
Merge branch 'dev' of github.com:Xilinx/brevitas into feat/conv3d
costigt-dev Jan 30, 2024
f6559cc
added condition for quantconv2d and made default case conv3d
costigt-dev Jan 31, 2024
bab6042
disable QuantConvTranspose3d in tests
costigt-dev Jan 31, 2024
22fe5c2
restored function necessary for QuantConv3d to be tested
costigt-dev Jan 31, 2024
ee01b84
restored QuantConvTranspose3d to tests
costigt-dev Jan 31, 2024
04ca06c
pre-commit changes
costigt-dev Jan 31, 2024
5ad14d4
Merge branch 'dev' of github.com:Xilinx/brevitas into feat/conv3d
costigt-dev Feb 1, 2024
382c50b
fixed missing parts in test code causing test failures
costigt-dev Feb 1, 2024
a068004
pre-commit hook changes
costigt-dev Feb 1, 2024
f30e36f
fixed typo - should be conv3d
costigt-dev Feb 7, 2024
f35cf4b
removed quantconv3d from flexml.py as it is unnecessary
costigt-dev Feb 8, 2024
8e7b647
made check for 3d version explicit instead of default case
costigt-dev Feb 8, 2024
479eb9f
added tests for conv1d,2d,3d merge batch norm
costigt-dev Feb 8, 2024
51771bc
added tests to check if avgpool is replace with quantconvs and that m…
costigt-dev Feb 8, 2024
4828856
reordered items in SUPPORTED_CONV_OP
costigt-dev Feb 12, 2024
82ea0da
collapsed into isInstance
costigt-dev Feb 12, 2024
9c40dc0
Merge branch 'master' of github.com:Xilinx/brevitas into feat/conv3d
costigt-dev Feb 20, 2024
b827129
resolved merge conflict
costigt-dev Feb 20, 2024
ea729c8
Revert "removed quantconv3d from flexml.py as it is unnecessary"
costigt-dev Feb 22, 2024
a8ad695
correct incorrect value in Kernel3dApplHandlerMixin from 4 to 3
costigt-dev Feb 27, 2024
4c76e4f
updated comments for tensor shapes
costigt-dev Mar 6, 2024
5b94089
updated convtranspose method based on PR suggestion
costigt-dev Mar 6, 2024
b607c8f
added max(...,1) to patch_size calculation
costigt-dev Mar 6, 2024
6271f40
added some basic tests for convtranspose
costigt-dev Mar 6, 2024
6a651c3
updated copyright year
costigt-dev Mar 6, 2024
f4f92e6
changed rounding from floor equivalent to ceil
costigt-dev Mar 6, 2024
3e5db0a
switched torch.ceil to math.ceil
costigt-dev Mar 6, 2024
32 changes: 32 additions & 0 deletions src/brevitas/export/onnx/handler.py
@@ -88,6 +88,38 @@ def kernel_shape(module):
return list(module.kernel_size)


class Kernel3dApplHandlerMixin(ABC):

@staticmethod
def padding(module):
if isinstance(module.padding, int):
padding = [module.padding] * 6
else:
padding = list(module.padding) + list(module.padding)
return padding

@staticmethod
def stride(module):
if isinstance(module.stride, int):
return [module.stride] * 3
else:
return list(module.stride)

@staticmethod
def dilation(module):
if isinstance(module.dilation, int):
return [module.dilation] * 3
else:
return list(module.dilation)

@staticmethod
def kernel_shape(module):
if isinstance(module.kernel_size, int):
return [module.kernel_size] * 3
else:
return list(module.kernel_size)


class ONNXBaseHandler(BaseHandler, ABC):

def __init__(self):
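The new `Kernel3dApplHandlerMixin` expands scalar conv attributes into per-axis lists for ONNX export; padding alone is duplicated to six entries because ONNX `pads` lists begin- and end-padding separately per spatial axis. A standalone sketch of that expansion (plain functions for illustration, not the Brevitas static methods):

```python
def padding_3d(padding):
    # ONNX Conv/ConvTranspose take `pads` as six entries for a 3-D kernel:
    # [d_begin, h_begin, w_begin, d_end, h_end, w_end].
    # PyTorch stores either one int or a per-axis tuple, so an int is
    # replicated six times and a tuple is repeated for begin and end.
    if isinstance(padding, int):
        return [padding] * 6
    return list(padding) + list(padding)


def stride_3d(stride):
    # Strides (and likewise dilations and kernel shapes) need only one
    # entry per spatial axis.
    if isinstance(stride, int):
        return [stride] * 3
    return list(stride)
```

This is why the mixin's `padding` branch multiplies by 6 while `stride`, `dilation`, and `kernel_shape` multiply by 3.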
10 changes: 8 additions & 2 deletions src/brevitas/export/onnx/standard/qoperator/handler/parameter.py
@@ -8,11 +8,13 @@
from brevitas.export.common import to_0dim_if_scalar
from brevitas.export.onnx.handler import Kernel1dApplHandlerMixin
from brevitas.export.onnx.handler import Kernel2dApplHandlerMixin
from brevitas.export.onnx.handler import Kernel3dApplHandlerMixin
from brevitas.export.onnx.standard.function import DequantizeLinearFn
from brevitas.export.onnx.standard.function import IntClipFn
from brevitas.export.onnx.standard.function import QuantizeLinearFn
from brevitas.nn import QuantConv1d
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantConv3d
from brevitas.nn import QuantLinear
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL

@@ -56,7 +58,7 @@ def validate(cls, module: QuantWBIOL, requires_quant_bias=True):
assert module.is_quant_bias_signed
cls.validate_32b_bit_width(module.quant_bias_bit_width(), le_then=True)

def prepare_for_export(self, module: Union[QuantConv1d, QuantConv2d]):
def prepare_for_export(self, module: Union[QuantConv1d, QuantConv2d, QuantConv3d]):
self.validate(module)

op_symbolic_kwargs = self.op_symbolic_kwargs(module)
@@ -106,7 +108,7 @@ def output_symbolic_execution(self, out: Tensor):

class StdQOpONNXQuantConvNdHandler(StdQOpONNXQuantWBIOLHandler, ABC):

def op_symbolic_kwargs(self, module: Union[QuantConv1d, QuantConv2d]):
def op_symbolic_kwargs(self, module: Union[QuantConv1d, QuantConv2d, QuantConv3d]):
conv_symbolic_kwargs = {
'input_scale': module.quant_input_scale(),
'input_zero_point': self.quant_input_zero_point(module),
@@ -131,6 +133,10 @@ def op_symbolic_execution(self, inp: Tensor):
return out


class StdQOpONNXQuantConv3dHandler(StdQOpONNXQuantConvNdHandler, Kernel3dApplHandlerMixin):
handled_layer = QuantConv3d


class StdQOpONNXQuantConv2dHandler(StdQOpONNXQuantConvNdHandler, Kernel2dApplHandlerMixin):
handled_layer = QuantConv2d

2 changes: 2 additions & 0 deletions src/brevitas/export/onnx/standard/qoperator/manager.py
@@ -26,6 +26,7 @@
from .handler.base import StdQOpONNXQuantLayerHandler
from .handler.parameter import StdQOpONNXQuantConv1dHandler
from .handler.parameter import StdQOpONNXQuantConv2dHandler
from .handler.parameter import StdQOpONNXQuantConv3dHandler
from .handler.parameter import StdQOpONNXQuantLinearHandler
from .handler.pool import StdQOpONNXQuantMaxPool1d
from .handler.pool import StdQOpONNXQuantMaxPool2d
@@ -48,6 +49,7 @@ class StdQOpONNXManager(StdONNXBaseManager):
handlers = [
StdQOpONNXQuantConv1dHandler,
StdQOpONNXQuantConv2dHandler,
StdQOpONNXQuantConv3dHandler,
StdQOpONNXQuantLinearHandler,
StdQOpONNXQuantReLUHandler,
StdQOpONNXQuantHardTanhHandler,
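Each export manager resolves a layer to its handler through the `handled_layer` class attribute, so supporting `QuantConv3d` amounts to one new handler class plus one entry in the `handlers` list. A minimal sketch of that dispatch pattern (class and string names here are illustrative, not Brevitas APIs):

```python
class Handler:
    # Each concrete handler declares which layer class it serves.
    handled_layer = None


class Conv2dHandler(Handler):
    handled_layer = 'QuantConv2d'


class Conv3dHandler(Handler):
    handled_layer = 'QuantConv3d'


HANDLERS = [Conv2dHandler, Conv3dHandler]


def find_handler(layer_type):
    # The first registered handler whose handled_layer matches wins,
    # so extending to 3d is just one more entry in the list.
    for handler in HANDLERS:
        if handler.handled_layer == layer_type:
            return handler
    raise KeyError(f'no handler for {layer_type}')
```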
11 changes: 10 additions & 1 deletion src/brevitas/export/torch/qoperator/handler/parameter.py
@@ -10,6 +10,7 @@

from brevitas.nn import QuantConv1d
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantConv3d
from brevitas.nn import QuantLinear
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL

@@ -93,7 +94,7 @@ def explicit_output_dtype(cls):
return True

@classmethod
def prepare_qf_kwargs(cls, module: Union[QuantConv1d, QuantConv2d]):
def prepare_qf_kwargs(cls, module: Union[QuantConv1d, QuantConv2d, QuantConv3d]):
return {
'bias': cls.prepare_bias(module),
'stride': module.stride,
@@ -119,6 +120,14 @@ def prepare_qf(cls, module: QuantConv2d):
return torch.nn.quantized.functional.conv2d, cls.prepare_qf_kwargs(module)


class PytorchQuantConv3dHandler(PytorchQuantConvNdHandler):
handled_layer = QuantConv3d

@classmethod
def prepare_qf(cls, module: QuantConv3d):
return torch.nn.quantized.functional.conv3d, cls.prepare_qf_kwargs(module)


class PytorchQuantLinearHandler(PytorchQuantWBIOLHandler):
handled_layer = QuantLinear

2 changes: 2 additions & 0 deletions src/brevitas/export/torch/qoperator/manager.py
@@ -19,6 +19,7 @@
from .handler.act import PytorchQuantReLUHandler
from .handler.parameter import PytorchQuantConv1dHandler
from .handler.parameter import PytorchQuantConv2dHandler
from .handler.parameter import PytorchQuantConv3dHandler
from .handler.parameter import PytorchQuantLinearHandler
from .handler.pool import PytorchQuantMaxPool1d
from .handler.pool import PytorchQuantMaxPool2d
@@ -35,6 +36,7 @@ class TorchQOpManager(BaseManager):
PytorchQuantReLUHandler,
PytorchQuantConv1dHandler,
PytorchQuantConv2dHandler,
PytorchQuantConv3dHandler,
PytorchQuantLinearHandler]

@classmethod
11 changes: 8 additions & 3 deletions src/brevitas/graph/fixed_point.py
@@ -33,11 +33,14 @@ class MoveSplitBatchNormBeforeCat(UntilFixedPointGraphTransform):
nn.Conv3d,
nn.ConvTranspose1d,
nn.ConvTranspose2d,
nn.ConvTranspose3d,
qnn.QuantLinear,
qnn.QuantConv1d,
qnn.QuantConv2d,
qnn.QuantConv3d,
qnn.QuantConvTranspose1d,
qnn.QuantConvTranspose2d)
qnn.QuantConvTranspose2d,
qnn.QuantConvTranspose3d)

def __init__(self, before_modules_types=DEFAULT_BEFORE_MODULES_TYPES):
super(MoveSplitBatchNormBeforeCat, self).__init__()
@@ -93,8 +96,10 @@ class MergeBatchNorm(UntilFixedPointGraphTransform):
nn.BatchNorm1d), (qnn.BatchNorm2dToQuantScaleBias,
nn.BatchNorm2d), (qnn.QuantLinear, nn.BatchNorm1d),
(qnn.QuantConv1d, nn.BatchNorm1d), (qnn.QuantConv2d, nn.BatchNorm2d),
(qnn.QuantConvTranspose1d,
nn.BatchNorm1d), (qnn.QuantConvTranspose2d, nn.BatchNorm2d))
(qnn.QuantConv3d,
nn.BatchNorm3d), (qnn.QuantConvTranspose1d, nn.BatchNorm1d),
(qnn.QuantConvTranspose2d,
nn.BatchNorm2d), (qnn.QuantConvTranspose3d, nn.BatchNorm3d))

def __init__(self, patterns=DEFAULT_PATTERNS):
super(MergeBatchNorm, self).__init__()
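`MergeBatchNorm` now pairs `QuantConv3d` with `nn.BatchNorm3d` (and likewise the transposed variant). The underlying fold is dimension-agnostic: each output channel's weights and bias are rescaled by the batch-norm statistics, regardless of how many kernel elements the channel has. A hedged sketch of the per-channel arithmetic (pure Python, not the Brevitas implementation):

```python
import math


def fuse_bn_into_conv(weight, bias, gamma, beta, mean, var, eps=1e-5):
    # Per output channel c, folding BN into the preceding conv gives:
    #   w'_c = w_c * gamma_c / sqrt(var_c + eps)
    #   b'_c = (b_c - mean_c) * gamma_c / sqrt(var_c + eps) + beta_c
    # The rule is identical for 1d, 2d, and 3d convs; only the number of
    # kernel elements per channel changes.
    fused_w, fused_b = [], []
    for c in range(len(weight)):
        scale = gamma[c] / math.sqrt(var[c] + eps)
        fused_w.append([w * scale for w in weight[c]])
        fused_b.append((bias[c] - mean[c]) * scale + beta[c])
    return fused_w, fused_b
```

This is why the tests added in commit 479eb9f can cover conv1d/2d/3d batch-norm merging with one parametrized pattern.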
8 changes: 6 additions & 2 deletions src/brevitas/graph/gpfq.py
@@ -169,7 +169,9 @@ def update_batch(self, module, input, current_layer):

if isinstance(self.layer, SUPPORTED_CONV_OP):
# Pick the correct unfoldNd class
if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
if isinstance(
self.layer,
(qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)):
unfold_impl = unfoldNd.UnfoldTransposeNd
else:
unfold_impl = unfoldNd.UnfoldNd
@@ -233,7 +235,9 @@ def single_layer_update(self):
dev = weight.device
dtype = weight.dtype
if isinstance(self.layer, SUPPORTED_CONV_OP):
if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
if isinstance(
self.layer,
(qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)):
weight = weight.transpose(1, 0) # This performs a view
weight = weight.flatten(1)
weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC]
8 changes: 6 additions & 2 deletions src/brevitas/graph/gptq.py
@@ -152,7 +152,9 @@ def update_batch(self, module, input, current_layer):

if isinstance(self.layer, SUPPORTED_CONV_OP):
# Pick the correct unfoldNd class
if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
if isinstance(
self.layer,
(qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)):
unfold_impl = unfoldNd.UnfoldTransposeNd
else:
unfold_impl = unfoldNd.UnfoldNd
@@ -199,7 +201,9 @@ def single_layer_update(self, percdamp=.01):
dtype = weight.dtype

if isinstance(self.layer, SUPPORTED_CONV_OP):
if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
if isinstance(
self.layer,
(qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)):
weight = weight.transpose(1, 0) # This performs a view
weight = weight.flatten(1)

11 changes: 9 additions & 2 deletions src/brevitas/graph/gpxq.py
@@ -20,7 +20,12 @@
from brevitas.quant_tensor import QuantTensor

SUPPORTED_CONV_OP = (
qnn.QuantConv2d, qnn.QuantConv1d, qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)
qnn.QuantConv1d,
qnn.QuantConv2d,
qnn.QuantConv3d,
qnn.QuantConvTranspose1d,
qnn.QuantConvTranspose2d,
qnn.QuantConvTranspose3d)


class StopFwdException(Exception):
@@ -193,7 +198,9 @@ def __init__(
# By default, use groups = 1
self.groups = 1
if isinstance(self.layer, SUPPORTED_CONV_OP):
if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
if isinstance(
self.layer,
(qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)):
weight = weight.transpose(1, 0) # This performs a view
weight = weight.flatten(1)
self.groups = self.layer.groups
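The repeated `transpose(1, 0)` in gpfq/gptq/gpxq exists because PyTorch stores `ConvTranspose*d` weights as `(in_channels, out_channels/groups, *kernel)`, whereas GPxQ wants rows indexed by output channels before flattening. A shape-only sketch of the resulting matrix view (an illustrative helper, not Brevitas code):

```python
def gpxq_weight_matrix_shape(shape, transposed):
    # Conv weights are (out_channels, in_channels/groups, *kernel);
    # transposed convs store (in_channels, out_channels/groups, *kernel).
    # GPxQ wants rows indexed by output channels, hence the transpose(1, 0)
    # before flattening everything after dim 0.
    if transposed:
        shape = (shape[1], shape[0]) + tuple(shape[2:])
    rows = shape[0]
    cols = 1
    for dim in shape[1:]:
        cols *= dim
    return rows, cols
```

The same branch now covers all three transposed variants, which is why `QuantConvTranspose3d` only needed to join the isinstance tuple.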
6 changes: 4 additions & 2 deletions src/brevitas/graph/per_input.py
@@ -10,6 +10,7 @@
from brevitas.graph.utils import replace_module
from brevitas.nn import QuantConv1d
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantConv3d

from .base import PerInputModuleToModuleByHook

@@ -93,9 +94,10 @@ def replace_modules(self, model):
dw_conv = QuantConv1d(**kwargs)
elif isinstance(avgpool, nn.AvgPool2d):
dw_conv = QuantConv2d(**kwargs)
elif isinstance(avgpool, nn.AvgPool3d):
dw_conv = QuantConv3d(**kwargs)
else:
assert isinstance(avgpool, nn.AvgPool3d)
raise RuntimeError("QuantConv3d not supported yet.")
raise RuntimeError("Unsupported operation.")
kernel_value = 1. / reduce(mul, dw_conv.kernel_size)
dw_conv.register_parameter(
'scalar_weight', torch.nn.Parameter(torch.tensor(kernel_value)))
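The `per_input` rewrite replaces an `AvgPool3d` with a depthwise `QuantConv3d` whose every weight equals `1 / prod(kernel_size)`, so the conv reproduces the mean exactly; that scalar is what gets registered as `scalar_weight`. The computation in isolation:

```python
from functools import reduce
from operator import mul


def avgpool_weight_value(kernel_size):
    # Averaging over a window of size k equals a depthwise conv whose
    # every weight is 1 / prod(kernel_size).
    return 1.0 / reduce(mul, kernel_size)
```

For a `(2, 2, 2)` pooling window this gives 1/8, matching the average over eight voxels.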
26 changes: 26 additions & 0 deletions src/brevitas/graph/quantize.py
@@ -82,6 +82,12 @@
'weight_quant': Int8WeightPerTensorFloat,
'bias_quant': Int32Bias,
'return_quant_tensor': True}),
nn.Conv3d: (
qnn.QuantConv3d,
{
'weight_quant': Int8WeightPerTensorFloat,
'bias_quant': Int32Bias,
'return_quant_tensor': True}),
nn.ConvTranspose1d: (
qnn.QuantConvTranspose1d,
{
@@ -94,6 +100,12 @@
'weight_quant': Int8WeightPerTensorFloat,
'bias_quant': Int32Bias,
'return_quant_tensor': True}),
nn.ConvTranspose3d: (
qnn.QuantConvTranspose3d,
{
'weight_quant': Int8WeightPerTensorFloat,
'bias_quant': Int32Bias,
'return_quant_tensor': True}),
nn.Linear: (
qnn.QuantLinear,
{
@@ -151,6 +163,13 @@
'weight_quant': Int8WeightPerTensorFloat,
'bias_quant': Int32Bias,
'return_quant_tensor': False}),
nn.Conv3d: (
qnn.QuantConv3d,
{
'input_quant': Int8ActPerTensorFloat,
'weight_quant': Int8WeightPerTensorFloat,
'bias_quant': Int32Bias,
'return_quant_tensor': False}),
nn.ConvTranspose1d: (
qnn.QuantConvTranspose1d,
{
Expand All @@ -165,6 +184,13 @@
'weight_quant': Int8WeightPerTensorFloat,
'bias_quant': Int32Bias,
'return_quant_tensor': False}),
nn.ConvTranspose3d: (
qnn.QuantConvTranspose3d,
{
'input_quant': Int8ActPerTensorFloat,
'weight_quant': Int8WeightPerTensorFloat,
'bias_quant': Int32Bias,
'return_quant_tensor': False}),
nn.Linear: (
qnn.QuantLinear,
{
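The quantization maps gain `Conv3d`/`ConvTranspose3d` entries that mirror their 1d/2d siblings with identical quantizer kwargs. Since all six conv entries share one kwarg set, the symmetry could equally be expressed programmatically; a sketch of that pattern (strings stand in for the actual quantizer and layer classes):

```python
# Shared quantizer configuration for every conv variant in the map.
COMMON_QUANT_KWARGS = {
    'weight_quant': 'Int8WeightPerTensorFloat',
    'bias_quant': 'Int32Bias',
    'return_quant_tensor': True}

LAYER_MAP = {}
for nd in (1, 2, 3):
    # Each dimensionality gets a plain and a transposed entry with the
    # same kwargs, copied so per-layer tweaks stay independent.
    LAYER_MAP[f'Conv{nd}d'] = (f'QuantConv{nd}d', dict(COMMON_QUANT_KWARGS))
    LAYER_MAP[f'ConvTranspose{nd}d'] = (
        f'QuantConvTranspose{nd}d', dict(COMMON_QUANT_KWARGS))
```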
12 changes: 12 additions & 0 deletions src/brevitas/graph/target/flexml.py
@@ -45,6 +45,12 @@
'weight_quant': Int8WeightPerTensorFixedPoint,
'bias_quant': Int16Bias,
'return_quant_tensor': True}),
nn.Conv3d: (
qnn.QuantConv3d,
{
'weight_quant': Int8WeightPerTensorFixedPoint,
'bias_quant': Int16Bias,
'return_quant_tensor': True}),
nn.ConvTranspose1d: (
qnn.QuantConvTranspose1d,
{
@@ -57,6 +63,12 @@
'weight_quant': Int8WeightPerTensorFixedPoint,
'bias_quant': Int16Bias,
'return_quant_tensor': True}),
nn.ConvTranspose3d: (
qnn.QuantConvTranspose3d,
{
'weight_quant': Int8WeightPerTensorFixedPoint,
'bias_quant': Int16Bias,
'return_quant_tensor': True}),
nn.BatchNorm1d: (
qnn.BatchNorm1dToQuantScaleBias,
{
3 changes: 2 additions & 1 deletion src/brevitas/graph/utils.py
@@ -31,7 +31,8 @@
nn.ConvTranspose2d,
nn.ConvTranspose3d,
qnn.QuantConvTranspose1d,
qnn.QuantConvTranspose2d]
qnn.QuantConvTranspose2d,
qnn.QuantConvTranspose3d]


def module_class_name(m: torch.nn.Module):
2 changes: 2 additions & 0 deletions src/brevitas/nn/__init__.py
@@ -15,8 +15,10 @@
from .quant_bn import BatchNorm2dToQuantScaleBias
from .quant_conv import QuantConv1d
from .quant_conv import QuantConv2d
from .quant_conv import QuantConv3d
from .quant_convtranspose import QuantConvTranspose1d
from .quant_convtranspose import QuantConvTranspose2d
from .quant_convtranspose import QuantConvTranspose3d
from .quant_eltwise import QuantCat
from .quant_eltwise import QuantEltwiseAdd
from .quant_embedding import QuantEmbedding