diff --git a/src/brevitas/export/onnx/handler.py b/src/brevitas/export/onnx/handler.py index f856eaeee..469c49907 100644 --- a/src/brevitas/export/onnx/handler.py +++ b/src/brevitas/export/onnx/handler.py @@ -88,6 +88,38 @@ def kernel_shape(module): return list(module.kernel_size) +class Kernel3dApplHandlerMixin(ABC): + + @staticmethod + def padding(module): + if isinstance(module.padding, int): + padding = [module.padding] * 6 + else: + padding = list(module.padding) + list(module.padding) + return padding + + @staticmethod + def stride(module): + if isinstance(module.stride, int): + return [module.stride] * 3 + else: + return list(module.stride) + + @staticmethod + def dilation(module): + if isinstance(module.dilation, int): + return [module.dilation] * 3 + else: + return list(module.dilation) + + @staticmethod + def kernel_shape(module): + if isinstance(module.kernel_size, int): + return [module.kernel_size] * 3 + else: + return list(module.kernel_size) + + class ONNXBaseHandler(BaseHandler, ABC): def __init__(self): diff --git a/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py b/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py index c4c552c65..6540b6f2d 100644 --- a/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py +++ b/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py @@ -8,11 +8,13 @@ from brevitas.export.common import to_0dim_if_scalar from brevitas.export.onnx.handler import Kernel1dApplHandlerMixin from brevitas.export.onnx.handler import Kernel2dApplHandlerMixin +from brevitas.export.onnx.handler import Kernel3dApplHandlerMixin from brevitas.export.onnx.standard.function import DequantizeLinearFn from brevitas.export.onnx.standard.function import IntClipFn from brevitas.export.onnx.standard.function import QuantizeLinearFn from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d +from brevitas.nn import QuantConv3d from brevitas.nn import QuantLinear from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL @@ -56,7 +58,7 @@ def validate(cls, module: QuantWBIOL, requires_quant_bias=True): assert module.is_quant_bias_signed cls.validate_32b_bit_width(module.quant_bias_bit_width(), le_then=True) - def prepare_for_export(self, module: Union[QuantConv1d, QuantConv2d]): + def prepare_for_export(self, module: Union[QuantConv1d, QuantConv2d, QuantConv3d]): self.validate(module) op_symbolic_kwargs = self.op_symbolic_kwargs(module) @@ -106,7 +108,7 @@ def output_symbolic_execution(self, out: Tensor): class StdQOpONNXQuantConvNdHandler(StdQOpONNXQuantWBIOLHandler, ABC): - def op_symbolic_kwargs(self, module: Union[QuantConv1d, QuantConv2d]): + def op_symbolic_kwargs(self, module: Union[QuantConv1d, QuantConv2d, QuantConv3d]): conv_symbolic_kwargs = { 'input_scale': module.quant_input_scale(), 'input_zero_point': self.quant_input_zero_point(module), @@ -131,6 +133,10 @@ def op_symbolic_execution(self, inp: Tensor): return out +class StdQOpONNXQuantConv3dHandler(StdQOpONNXQuantConvNdHandler, Kernel3dApplHandlerMixin): + handled_layer = QuantConv3d + + class StdQOpONNXQuantConv2dHandler(StdQOpONNXQuantConvNdHandler, Kernel2dApplHandlerMixin): handled_layer = QuantConv2d diff --git a/src/brevitas/export/onnx/standard/qoperator/manager.py b/src/brevitas/export/onnx/standard/qoperator/manager.py index 4c45df04c..174804407 100644 --- a/src/brevitas/export/onnx/standard/qoperator/manager.py +++ b/src/brevitas/export/onnx/standard/qoperator/manager.py @@ -26,6 +26,7 @@ from .handler.base import StdQOpONNXQuantLayerHandler from .handler.parameter import StdQOpONNXQuantConv1dHandler from .handler.parameter import StdQOpONNXQuantConv2dHandler +from .handler.parameter import StdQOpONNXQuantConv3dHandler from .handler.parameter import StdQOpONNXQuantLinearHandler from .handler.pool import StdQOpONNXQuantMaxPool1d from .handler.pool import StdQOpONNXQuantMaxPool2d @@ -48,6 +49,7 @@ class StdQOpONNXManager(StdONNXBaseManager): handlers = [ StdQOpONNXQuantConv1dHandler, StdQOpONNXQuantConv2dHandler, + StdQOpONNXQuantConv3dHandler, StdQOpONNXQuantLinearHandler, StdQOpONNXQuantReLUHandler, StdQOpONNXQuantHardTanhHandler, diff --git a/src/brevitas/export/torch/qoperator/handler/parameter.py b/src/brevitas/export/torch/qoperator/handler/parameter.py index fa110a84c..802a5a053 100644 --- a/src/brevitas/export/torch/qoperator/handler/parameter.py +++ b/src/brevitas/export/torch/qoperator/handler/parameter.py @@ -10,6 +10,7 @@ from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d +from brevitas.nn import QuantConv3d from brevitas.nn import QuantLinear from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL @@ -93,7 +94,7 @@ def explicit_output_dtype(cls): return True @classmethod - def prepare_qf_kwargs(cls, module: Union[QuantConv1d, QuantConv2d]): + def prepare_qf_kwargs(cls, module: Union[QuantConv1d, QuantConv2d, QuantConv3d]): return { 'bias': cls.prepare_bias(module), 'stride': module.stride, @@ -119,6 +120,14 @@ def prepare_qf(cls, module: QuantConv2d): return torch.nn.quantized.functional.conv2d, cls.prepare_qf_kwargs(module) +class PytorchQuantConv3dHandler(PytorchQuantConvNdHandler): + handled_layer = QuantConv3d + + @classmethod + def prepare_qf(cls, module: QuantConv3d): + return torch.nn.quantized.functional.conv3d, cls.prepare_qf_kwargs(module) + + class PytorchQuantLinearHandler(PytorchQuantWBIOLHandler): handled_layer = QuantLinear diff --git a/src/brevitas/export/torch/qoperator/manager.py b/src/brevitas/export/torch/qoperator/manager.py index aea97957c..7567e6770 100644 --- a/src/brevitas/export/torch/qoperator/manager.py +++ b/src/brevitas/export/torch/qoperator/manager.py @@ -19,6 +19,7 @@ from .handler.act import PytorchQuantReLUHandler from .handler.parameter import PytorchQuantConv1dHandler from .handler.parameter import PytorchQuantConv2dHandler +from .handler.parameter import PytorchQuantConv3dHandler from .handler.parameter import PytorchQuantLinearHandler from .handler.pool import PytorchQuantMaxPool1d from .handler.pool import PytorchQuantMaxPool2d @@ -35,6 +36,7 @@ class TorchQOpManager(BaseManager): PytorchQuantReLUHandler, PytorchQuantConv1dHandler, PytorchQuantConv2dHandler, + PytorchQuantConv3dHandler, PytorchQuantLinearHandler] @classmethod diff --git a/src/brevitas/graph/fixed_point.py b/src/brevitas/graph/fixed_point.py index afbd32e67..fead31ba2 100644 --- a/src/brevitas/graph/fixed_point.py +++ b/src/brevitas/graph/fixed_point.py @@ -33,11 +33,14 @@ class MoveSplitBatchNormBeforeCat(UntilFixedPointGraphTransform): nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, + nn.ConvTranspose3d, qnn.QuantLinear, qnn.QuantConv1d, qnn.QuantConv2d, + qnn.QuantConv3d, qnn.QuantConvTranspose1d, - qnn.QuantConvTranspose2d) + qnn.QuantConvTranspose2d, + qnn.QuantConvTranspose3d) def __init__(self, before_modules_types=DEFAULT_BEFORE_MODULES_TYPES): super(MoveSplitBatchNormBeforeCat, self).__init__() @@ -93,8 +96,10 @@ class MergeBatchNorm(UntilFixedPointGraphTransform): nn.BatchNorm1d), (qnn.BatchNorm2dToQuantScaleBias, nn.BatchNorm2d), (qnn.QuantLinear, nn.BatchNorm1d), (qnn.QuantConv1d, nn.BatchNorm1d), (qnn.QuantConv2d, nn.BatchNorm2d), - (qnn.QuantConvTranspose1d, - nn.BatchNorm1d), (qnn.QuantConvTranspose2d, nn.BatchNorm2d)) + (qnn.QuantConv3d, + nn.BatchNorm3d), (qnn.QuantConvTranspose1d, nn.BatchNorm1d), + (qnn.QuantConvTranspose2d, + nn.BatchNorm2d), (qnn.QuantConvTranspose3d, nn.BatchNorm3d)) def __init__(self, patterns=DEFAULT_PATTERNS): super(MergeBatchNorm, self).__init__() diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index e255660a0..46a58ebe2 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -169,7 +169,9 @@ def update_batch(self, module, input, current_layer): if isinstance(self.layer, SUPPORTED_CONV_OP): # Pick the correct unfoldNd class - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + if isinstance( + self.layer, + (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): unfold_impl = unfoldNd.UnfoldTransposeNd else: unfold_impl = unfoldNd.UnfoldNd @@ -233,7 +235,9 @@ def single_layer_update(self): dev = weight.device dtype = weight.dtype if isinstance(self.layer, SUPPORTED_CONV_OP): - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + if isinstance( + self.layer, + (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): weight = weight.transpose(1, 0) # This performs a view weight = weight.flatten(1) weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index 28cb12cd6..6e9bf3497 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -152,7 +152,9 @@ def update_batch(self, module, input, current_layer): if isinstance(self.layer, SUPPORTED_CONV_OP): # Pick the correct unfoldNd class - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + if isinstance( + self.layer, + (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): unfold_impl = unfoldNd.UnfoldTransposeNd else: unfold_impl = unfoldNd.UnfoldNd @@ -199,7 +201,9 @@ def single_layer_update(self, percdamp=.01): dtype = weight.dtype if isinstance(self.layer, SUPPORTED_CONV_OP): - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + if isinstance( + self.layer, + (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): weight = weight.transpose(1, 0) # This performs a view weight = weight.flatten(1) diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index 149e8ec03..3eb2ab97a 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -20,7 +20,12 @@ from brevitas.quant_tensor import QuantTensor SUPPORTED_CONV_OP = ( - qnn.QuantConv2d, qnn.QuantConv1d, qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d) + qnn.QuantConv1d, + qnn.QuantConv2d, + qnn.QuantConv3d, + qnn.QuantConvTranspose1d, + qnn.QuantConvTranspose2d, + qnn.QuantConvTranspose3d) class StopFwdException(Exception): @@ -193,7 +198,9 @@ def __init__( # By default, use groups = 1 self.groups = 1 if isinstance(self.layer, SUPPORTED_CONV_OP): - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + if isinstance( + self.layer, + (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): weight = weight.transpose(1, 0) # This performs a view weight = weight.flatten(1) self.groups = self.layer.groups diff --git a/src/brevitas/graph/per_input.py b/src/brevitas/graph/per_input.py index 37f5f6eac..908c73a2d 100644 --- a/src/brevitas/graph/per_input.py +++ b/src/brevitas/graph/per_input.py @@ -10,6 +10,7 @@ from brevitas.graph.utils import replace_module from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d +from brevitas.nn import QuantConv3d from .base import PerInputModuleToModuleByHook @@ -93,9 +94,10 @@ def replace_modules(self, model): dw_conv = QuantConv1d(**kwargs) elif isinstance(avgpool, nn.AvgPool2d): dw_conv = QuantConv2d(**kwargs) + elif isinstance(avgpool, nn.AvgPool3d): + dw_conv = QuantConv3d(**kwargs) else: - assert isinstance(avgpool, nn.AvgPool3d) - raise RuntimeError("QuantConv3d not supported yet.") + raise RuntimeError("Unsupported operation.") kernel_value = 1. / reduce(mul, dw_conv.kernel_size) dw_conv.register_parameter( 'scalar_weight', torch.nn.Parameter(torch.tensor(kernel_value))) diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py index 4ac324c8b..b1b94b5da 100644 --- a/src/brevitas/graph/quantize.py +++ b/src/brevitas/graph/quantize.py @@ -82,6 +82,12 @@ 'weight_quant': Int8WeightPerTensorFloat, 'bias_quant': Int32Bias, 'return_quant_tensor': True}), + nn.Conv3d: ( + qnn.QuantConv3d, + { + 'weight_quant': Int8WeightPerTensorFloat, + 'bias_quant': Int32Bias, + 'return_quant_tensor': True}), nn.ConvTranspose1d: ( qnn.QuantConvTranspose1d, { @@ -94,6 +100,12 @@ 'weight_quant': Int8WeightPerTensorFloat, 'bias_quant': Int32Bias, 'return_quant_tensor': True}), + nn.ConvTranspose3d: ( + qnn.QuantConvTranspose3d, + { + 'weight_quant': Int8WeightPerTensorFloat, + 'bias_quant': Int32Bias, + 'return_quant_tensor': True}), nn.Linear: ( qnn.QuantLinear, { @@ -151,6 +163,13 @@ 'weight_quant': Int8WeightPerTensorFloat, 'bias_quant': Int32Bias, 'return_quant_tensor': False}), + nn.Conv3d: ( + qnn.QuantConv3d, + { + 'input_quant': Int8ActPerTensorFloat, + 'weight_quant': Int8WeightPerTensorFloat, + 'bias_quant': Int32Bias, + 'return_quant_tensor': False}), nn.ConvTranspose1d: ( qnn.QuantConvTranspose1d, { @@ -165,6 +184,13 @@ 'weight_quant': Int8WeightPerTensorFloat, 'bias_quant': Int32Bias, 'return_quant_tensor': False}), + nn.ConvTranspose3d: ( + qnn.QuantConvTranspose3d, + { + 'input_quant': Int8ActPerTensorFloat, + 'weight_quant': Int8WeightPerTensorFloat, + 'bias_quant': Int32Bias, + 'return_quant_tensor': False}), nn.Linear: ( qnn.QuantLinear, { diff --git a/src/brevitas/graph/target/flexml.py b/src/brevitas/graph/target/flexml.py index 9aedd337c..35b9a4d14 100644 --- a/src/brevitas/graph/target/flexml.py +++ b/src/brevitas/graph/target/flexml.py @@ -45,6 +45,12 @@ 'weight_quant': Int8WeightPerTensorFixedPoint, 'bias_quant': Int16Bias, 'return_quant_tensor': True}), + nn.Conv3d: ( + qnn.QuantConv3d, + { + 'weight_quant': Int8WeightPerTensorFixedPoint, + 'bias_quant': Int16Bias, + 'return_quant_tensor': True}), nn.ConvTranspose1d: ( qnn.QuantConvTranspose1d, { @@ -57,6 +63,12 @@ 'weight_quant': Int8WeightPerTensorFixedPoint, 'bias_quant': Int16Bias, 'return_quant_tensor': True}), + nn.ConvTranspose3d: ( + qnn.QuantConvTranspose3d, + { + 'weight_quant': Int8WeightPerTensorFixedPoint, + 'bias_quant': Int16Bias, + 'return_quant_tensor': True}), nn.BatchNorm1d: ( qnn.BatchNorm1dToQuantScaleBias, { diff --git a/src/brevitas/graph/utils.py b/src/brevitas/graph/utils.py index 2ddfc9b2c..ed00de3eb 100644 --- a/src/brevitas/graph/utils.py +++ b/src/brevitas/graph/utils.py @@ -31,7 +31,8 @@ nn.ConvTranspose2d, nn.ConvTranspose3d, qnn.QuantConvTranspose1d, - qnn.QuantConvTranspose2d] + qnn.QuantConvTranspose2d, + qnn.QuantConvTranspose3d] def module_class_name(m: torch.nn.Module): diff --git a/src/brevitas/nn/__init__.py b/src/brevitas/nn/__init__.py index cc96889b6..4d58ab66b 100644 --- a/src/brevitas/nn/__init__.py +++ b/src/brevitas/nn/__init__.py @@ -15,8 +15,10 @@ from .quant_bn import BatchNorm2dToQuantScaleBias from .quant_conv import QuantConv1d from .quant_conv import QuantConv2d +from .quant_conv import QuantConv3d from .quant_convtranspose import QuantConvTranspose1d from .quant_convtranspose import QuantConvTranspose2d +from .quant_convtranspose import QuantConvTranspose3d from .quant_eltwise import QuantCat from .quant_eltwise import QuantEltwiseAdd from .quant_embedding import QuantEmbedding diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py index 5af432af7..74912af67 100644 --- a/src/brevitas/nn/quant_conv.py +++ b/src/brevitas/nn/quant_conv.py @@ -8,6 +8,7 @@ from torch import Tensor from torch.nn import Conv1d from torch.nn import Conv2d +from torch.nn import Conv3d from torch.nn import functional as F from brevitas.function.ops import max_int @@ -20,7 +21,7 @@ from .quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL from .quant_layer import WeightQuantType -__all__ = ['QuantConv1d', 'QuantConv2d'] +__all__ = ['QuantConv1d', 'QuantConv2d', 'QuantConv3d'] class QuantConv1d(QuantWBIOL, Conv1d): @@ -175,8 +176,9 @@ def per_elem_ops(self): @property def output_channel_dim(self): if self.transposed: - raise RuntimeError("Transposed kernels not supported") - return 0 + return 1 + else: + return 0 @property def channelwise_separable(self) -> bool: @@ -211,3 +213,113 @@ def max_acc_bit_width(self, input_bit_width: Tensor, weight_bit_width: Tensor): max_uint_output = max_uint_input * max_kernel_val * kernel_size * group_size max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) return max_output_bit_width + + +class QuantConv3d(QuantWBIOL, Conv3d): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]] = 1, + padding: Union[int, Tuple[int, int]] = 0, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + padding_mode: str = 'zeros', + bias: bool = True, + weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat, + bias_quant: Optional[BiasQuantType] = None, + input_quant: Optional[ActQuantType] = None, + output_quant: Optional[ActQuantType] = None, + return_quant_tensor: bool = False, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + **kwargs) -> None: + # avoid an init error in the super class by setting padding to 0 + if padding_mode == 'zeros' and padding == 'same' and stride > 1: + padding = 0 + is_same_padded_strided = True + else: + is_same_padded_strided = False + Conv3d.__init__( + self, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + padding_mode=padding_mode, + bias=bias, + device=device, + dtype=dtype) + QuantWBIOL.__init__( + self, + weight_quant=weight_quant, + bias_quant=bias_quant, + input_quant=input_quant, + output_quant=output_quant, + return_quant_tensor=return_quant_tensor, + **kwargs) + self.is_same_padded_strided = is_same_padded_strided + + @property + def per_elem_ops(self): + flat_kernel_size = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2] + return 2 * flat_kernel_size * (self.in_channels // self.groups) + + @property + def output_channel_dim(self): + if self.transposed: + return 1 + else: + return 0 + + @property + def channelwise_separable(self) -> bool: + # according to https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html + # if groups == in_channels that means each channel is convolved with its own set of filters + return self.groups == self.channels + + def conv3d_same_zeros_pad_stride(self, x, weight, bias): + id, ih, iw = x.size()[-3:] + kd, kh, kw = weight.size()[-3:] + sd, sh, sw = self.stride + od, oh, ow = math.ceil(id / sd), math.ceil(ih / sh), math.ceil(iw / sw) + pad_d = max((od - 1) * self.stride[0] + (kd - 1) * self.dilation[0] + 1 - id, 0) + pad_h = max((oh - 1) * self.stride[1] + (kh - 1) * self.dilation[1] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[2] + (kw - 1) * self.dilation[2] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad( + x, + [ + pad_w // 2, + pad_w - pad_w // 2, + pad_h // 2, + pad_h - pad_h // 2, + pad_d // 2, + pad_d - pad_d // 2]) + out = F.conv3d(x, weight, bias, self.stride, 0, self.dilation, self.groups) + return out + + def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: + # calls QuantWBIOL.forward_impl and eventually inner_forward_impl below + return self.forward_impl(input) + + # override of QuantWBIOL method, called by QuantWBIOL.forward_impl + def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]): + if self.is_same_padded_strided: + return self.conv3d_same_zeros_pad_stride(x, quant_weight, quant_bias) + else: + return self._conv_forward(x, quant_weight, quant_bias) + + def max_acc_bit_width(self, input_bit_width: Tensor, weight_bit_width: Tensor): + max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) + max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) + group_size = self.in_channels // self.groups + kernel_size = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2] + max_uint_output = max_uint_input * max_kernel_val * kernel_size * group_size + max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) + return max_output_bit_width diff --git a/src/brevitas/nn/quant_convtranspose.py b/src/brevitas/nn/quant_convtranspose.py index 64fbe8eb6..75dd90378 100644 --- a/src/brevitas/nn/quant_convtranspose.py +++ b/src/brevitas/nn/quant_convtranspose.py @@ -1,6 +1,7 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +import math from typing import Optional, Tuple, Type, Union from packaging import version @@ -8,8 +9,10 @@ from torch import Tensor from torch.nn import ConvTranspose1d from torch.nn import ConvTranspose2d +from torch.nn import ConvTranspose3d from torch.nn.functional import conv_transpose1d from torch.nn.functional import conv_transpose2d +from torch.nn.functional import conv_transpose3d from brevitas import torch_version from brevitas.function.ops import max_int @@ -22,7 +25,7 @@ from .quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL from .quant_layer import WeightQuantType -__all__ = ['QuantConvTranspose1d', 'QuantConvTranspose2d'] +__all__ = ['QuantConvTranspose1d', 'QuantConvTranspose2d', 'QuantConvTranspose3d'] class QuantConvTranspose1d(QuantWBIOL, ConvTranspose1d): @@ -116,8 +119,8 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) group_size = self.out_channels // self.groups - overlapping_sums = max(round(self.kernel_size[0] / self.stride[0]), 1) - max_uint_output = max_uint_input * max_kernel_val * overlapping_sums * group_size + patch_size = max(math.ceil(self.kernel_size[0] / self.stride[0]), 1) + max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) return max_output_bit_width @@ -213,8 +216,107 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) group_size = self.out_channels // self.groups - overlapping_sums = max(round(self.kernel_size[0] / self.stride[0]), 1) - overlapping_sums *= max(round(self.kernel_size[1] / self.stride[1]), 1) - max_uint_output = max_uint_input * max_kernel_val * overlapping_sums * group_size + patch_size = max(math.ceil(self.kernel_size[0] / self.stride[0]), 1) * max( + math.ceil(self.kernel_size[1] / self.stride[1]), 1) + max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size + max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) + return max_output_bit_width + + +class QuantConvTranspose3d(QuantWBIOL, ConvTranspose3d): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int]], + stride: Union[int, Tuple[int]] = 1, + padding: Union[int, Tuple[int]] = 0, + output_padding: Union[int, Tuple[int]] = 0, + dilation: Union[int, Tuple[int]] = 1, + groups: int = 1, + padding_mode: str = 'zeros', + bias: bool = True, + weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat, + bias_quant: Optional[BiasQuantType] = None, + input_quant: Optional[ActQuantType] = None, + output_quant: Optional[ActQuantType] = None, + return_quant_tensor: bool = False, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + **kwargs) -> None: + ConvTranspose3d.__init__( + self, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + dilation=dilation, + groups=groups, + padding_mode=padding_mode, + bias=bias, + device=device, + dtype=dtype) + QuantWBIOL.__init__( + self, + weight_quant=weight_quant, + bias_quant=bias_quant, + input_quant=input_quant, + output_quant=output_quant, + return_quant_tensor=return_quant_tensor, + **kwargs) + self._output_size = None + + @property + def per_elem_ops(self): + raise NotImplementedError + + @property + def output_channel_dim(self) -> int: + return 1 + + @property + def channelwise_separable(self) -> bool: + raise self.groups == self.out_channels + + def forward(self, + input: Union[Tensor, QuantTensor], + output_size=None) -> Union[Tensor, QuantTensor]: + self._output_size = output_size # cache the value temporarily + return self.forward_impl(input) + + def compute_output_padding(self, inp, output_size): + if torch_version >= version.parse('1.12'): + return self._output_padding( + inp, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims=3) + else: + return self._output_padding( + inp, output_size, self.stride, self.padding, self.kernel_size) + + def conv_transpose3d_zeros_pad( + self, x: Tensor, weight: Tensor, bias: Optional[Tensor], output_padding): + out = conv_transpose3d( + x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation) + return out + + def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]): + if self.padding_mode == 'zeros': + output_padding = self.compute_output_padding(x, self._output_size) + self._output_size = None # set it back to None after consuming it + out = self.conv_transpose3d_zeros_pad(x, quant_weight, quant_bias, output_padding) + return out + else: + raise NotImplementedError(f"Padding mode {self.padding_mode} not supported.") + + def max_acc_bit_width(self, input_bit_width, weight_bit_width): + max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) + max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) + group_size = self.out_channels // self.groups + patch_size = max(math.ceil(self.kernel_size[0] / self.stride[0]), 1) * max( + math.ceil(self.kernel_size[1] / self.stride[1]), 1) * max( + math.ceil(self.kernel_size[2] / self.stride[2]), 1) + max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) return max_output_bit_width diff --git a/tests/brevitas/export/quant_module_fixture.py b/tests/brevitas/export/quant_module_fixture.py index 8bfac6fd2..31524729f 100644 --- a/tests/brevitas/export/quant_module_fixture.py +++ b/tests/brevitas/export/quant_module_fixture.py @@ -9,8 +9,10 @@ from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d +from brevitas.nn import QuantConv3d from brevitas.nn import QuantConvTranspose1d from brevitas.nn import QuantConvTranspose2d +from brevitas.nn import QuantConvTranspose3d from brevitas.nn import QuantIdentity from brevitas.nn import QuantLinear from brevitas.nn import TruncAvgPool2d @@ -53,7 +55,13 @@ 'bias_external_scale': (Int32Bias,), 'bias_internal_scale': (Int8BiasPerTensorFloatInternalScaling,)} QUANT_WBIOL_IMPL = [ - QuantLinear, QuantConv1d, QuantConv2d, QuantConvTranspose1d, QuantConvTranspose2d] + QuantLinear, + QuantConv1d, + QuantConv2d, + QuantConv3d, + QuantConvTranspose1d, + QuantConvTranspose2d, + QuantConvTranspose3d,] BIT_WIDTHS = [4, 8, 10] # below 8, equal 8, above 8 BIAS_BIT_WIDTHS = [8, 16, 32] diff --git a/tests/brevitas/export/test_torch_qcdq.py b/tests/brevitas/export/test_torch_qcdq.py index 4f737f8d7..6019bf417 100644 --- a/tests/brevitas/export/test_torch_qcdq.py +++ b/tests/brevitas/export/test_torch_qcdq.py @@ -36,8 +36,10 @@ def test_torch_qcdq_wbiol_export( in_size = (1, IN_CH) elif quant_module_impl == QuantConv1d or quant_module_impl == QuantConvTranspose1d: in_size = (1, IN_CH, FEATURES) - else: + elif quant_module_impl == QuantConv2d or quant_module_impl == QuantConvTranspose2d: in_size = (1, IN_CH, FEATURES, FEATURES) + else: + in_size = (1, IN_CH, FEATURES, FEATURES, FEATURES) inp = torch.randn(in_size) quant_module(inp) # Collect scale factors diff --git a/tests/brevitas/graph/test_transforms.py b/tests/brevitas/graph/test_transforms.py index a4865ac24..c58d9d828 100644 --- a/tests/brevitas/graph/test_transforms.py +++ b/tests/brevitas/graph/test_transforms.py @@ -10,12 +10,16 @@ from torchvision import models from brevitas.fx import symbolic_trace +from brevitas.graph import AvgPoolToQuantDepthwiseConv from brevitas.graph import DuplicateSharedStatelessModule from brevitas.graph import FnToModule from brevitas.graph import MeanMethodToAdaptiveAvgPool2d from brevitas.graph import MergeBatchNorm from brevitas.graph import MethodToModule from brevitas.graph.base import ModuleToModuleByInstance +from brevitas.nn import QuantConv1d +from brevitas.nn import QuantConv2d +from brevitas.nn import QuantConv3d SEED = 123456 INPUT_SIZE = (1, 3, 224, 224) @@ -50,6 +54,86 @@ def test_rewriter_merge_bn(model_name: str, pretrained: bool): assert is_close +@pytest.mark.parametrize("dims", [1, 2, 3]) +def test_conv_merge_bn(dims): + + class TestModel(nn.Module): + + def __init__(self, dims): + super(TestModel, self).__init__() + layers = [] + + if dims == 1: + layers.append(nn.Conv1d(16, 33, 3, stride=2)) + layers.append(nn.BatchNorm1d(33)) + elif dims == 2: + layers.append(nn.Conv2d(16, 33, 3, stride=2)) + layers.append(nn.BatchNorm2d(33)) + else: + layers.append(nn.Conv3d(16, 33, 3, stride=2)) + layers.append(nn.BatchNorm3d(33)) + + layers.append(nn.ReLU()) + + self.net = nn.Sequential(*layers) + + def forward(self, x): + return self.net(x) + + model = TestModel(dims) + graph = symbolic_trace(model) + graph = MergeBatchNorm().apply(graph) + + for m in graph.modules(): + if dims == 1: + assert not isinstance(m, nn.BatchNorm1d) + elif dims == 2: + assert not isinstance(m, nn.BatchNorm2d) + else: + assert not isinstance(m, nn.BatchNorm3d) + + +@pytest.mark.parametrize("dims", [1, 2, 3]) +def test_avg_pool_to_quant_conv(dims): + + class TestModel(nn.Module): + + def __init__(self, dims): + super(TestModel, self).__init__() + + if dims == 1: + self.net = nn.Sequential(nn.AvgPool1d(3, stride=2), nn.ReLU()) + elif dims == 2: + self.net = nn.Sequential(nn.AvgPool2d(3, stride=2), nn.ReLU()) + else: + self.net = nn.Sequential(nn.AvgPool3d(3, stride=2), nn.ReLU()) + + def forward(self, x): + return self.net(x) + + model = TestModel(dims) + + args = None + if dims == 1: + args = torch.randn(20, 16, 10) + elif dims == 2: + args = torch.randn(20, 16, 10, 50) + else: + args = torch.randn(20, 16, 10, 50, 100) + + graph = symbolic_trace(model) + graph = AvgPoolToQuantDepthwiseConv().apply(graph, args) + + has_quant_conv = False + for m in graph.modules(): + if isinstance(m, (QuantConv1d, QuantConv2d, QuantConv3d)): + has_quant_conv = True + + assert not isinstance(m, (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d)) + + assert has_quant_conv + + def test_rewriter_duplicate_shared_relu(): class TestModel(nn.Module): diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py index 538e836e8..6f0ddd6e0 100644 --- a/tests/brevitas/nn/nn_quantizers_fixture.py +++ b/tests/brevitas/nn/nn_quantizers_fixture.py @@ -12,8 +12,10 @@ from brevitas import torch_version from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d +from brevitas.nn import QuantConv3d from brevitas.nn import QuantConvTranspose1d from brevitas.nn import QuantConvTranspose2d +from brevitas.nn import QuantConvTranspose3d from brevitas.nn import QuantIdentity from brevitas.nn import QuantLinear from brevitas.nn.quant_mha import QuantMultiheadAttention @@ -91,8 +93,10 @@ QuantLinear, QuantConv1d, QuantConv2d, + QuantConv3d, QuantConvTranspose1d, - QuantConvTranspose2d,] + QuantConvTranspose2d, + QuantConvTranspose3d,] ACC_BIT_WIDTHS = [8, 9, 10, 12, 16, 24, 32] @@ -157,8 +161,12 @@ def forward(self, x): in_size = (1, IN_CH) elif impl in ('QuantConv1d', 'QuantConvTranspose1d'): in_size = (1, IN_CH, FEATURES) - else: + elif impl in ('QuantConv2d', 'QuantConvTranspose2d'): in_size = (1, IN_CH, FEATURES, FEATURES) + elif impl in ('QuantConv3d', 'QuantConvTranspose3d'): + in_size = (1, IN_CH, FEATURES, FEATURES, FEATURES) + else: + raise RuntimeError("Unsupported operation") if input_quantized: quant_inp = QuantTensor( diff --git a/tests/brevitas/nn/test_a2q.py b/tests/brevitas/nn/test_a2q.py index 8c34f390c..fa8dd701a 100644 --- a/tests/brevitas/nn/test_a2q.py +++ b/tests/brevitas/nn/test_a2q.py @@ -100,14 +100,20 @@ def test_quant_wbiol_a2q(model_input, current_cases): elif kwargs['model_type'] == 'QuantConv1d': # shape = (out_channels, in_channels, kernel_size) quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(1, 2)) elif kwargs[ - 'model_type'] == 'QuantConv2d': # shape = (out_channels, in_channels, kernel_size, kernel_size) + 'model_type'] == 'QuantConv2d': # shape = (out_channels, in_channels, kernel_height, kernel_width) quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(1, 2, 3)) + elif kwargs[ + 'model_type'] == 'QuantConv3d': # shape = (out_channels, in_channels, kernel_depth, kernel_height, kernel_width) + quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(1, 2, 3, 4)) elif kwargs[ 'model_type'] == 'QuantConvTranspose1d': # shape = (in_channels, out_channels, kernel_size) quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2)) elif kwargs[ - 'model_type'] == 'QuantConvTranspose2d': # shape = (in_channels, out_channels, kernel_size) + 'model_type'] == 'QuantConvTranspose2d': # shape = (in_channels, out_channels, kernel_height, kernel_width) quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2, 3)) + elif kwargs[ + 'model_type'] == 'QuantConvTranspose3d': # shape = (in_channels, out_channels, kernel_depth, kernel_height, kernel_width) + quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2, 3, 4)) else: raise NotImplementedError(f"Check for {kwargs['model_type']} is not yet implemented.") diff --git a/tests/brevitas/nn/test_conv3d.py b/tests/brevitas/nn/test_conv3d.py new file mode 100644 index 000000000..fdd35524f --- /dev/null +++ b/tests/brevitas/nn/test_conv3d.py @@ -0,0 +1,103 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import torch +from torch.nn import BatchNorm3d +from torch.nn import Conv3d + +from brevitas.inject.defaults import Int8BiasPerTensorFloatInternalScaling +from brevitas.nn import QuantConv3d +from brevitas.nn.utils import merge_bn + +OUTPUT_CHANNELS = 10 +INPUT_CHANNELS = 5 +KERNEL_SIZE = (3, 3, 3) +WEIGHT_BIT_WIDTH = 5 + + +class TestQuantConv3d: + + def test_module_init(self): + mod = QuantConv3d( + out_channels=OUTPUT_CHANNELS, + in_channels=INPUT_CHANNELS, + kernel_size=KERNEL_SIZE, + bias=False) + + def test_fp_quant_module(self): + float_mod = Conv3d( + out_channels=OUTPUT_CHANNELS, + in_channels=INPUT_CHANNELS, + kernel_size=KERNEL_SIZE, + bias=False) + quant_mod = QuantConv3d( + out_channels=OUTPUT_CHANNELS, + in_channels=INPUT_CHANNELS, + kernel_size=KERNEL_SIZE, + weight_quant_type='FP', + bias=False) + quant_mod.load_state_dict(float_mod.state_dict()) + inp = torch.randn(1, INPUT_CHANNELS, 20, 20, 20) + out_float = float_mod(inp) + out_quant = quant_mod(inp) + assert out_float.isclose(out_quant).all().item() + + def test_none_weight_quant_module(self): + float_mod = Conv3d( + out_channels=OUTPUT_CHANNELS, + in_channels=INPUT_CHANNELS, + kernel_size=KERNEL_SIZE, + bias=False) + quant_mod = QuantConv3d( + out_channels=OUTPUT_CHANNELS, + in_channels=INPUT_CHANNELS, + kernel_size=KERNEL_SIZE, + weight_quant=None, + bias=False) + quant_mod.load_state_dict(float_mod.state_dict()) + inp = torch.randn(1, INPUT_CHANNELS, 20, 20, 20) + out_float = float_mod(inp) + out_quant = quant_mod(inp) + assert out_float.isclose(out_quant).all().item() + + def test_delayed_quant_module(self): + float_mod = Conv3d( + out_channels=OUTPUT_CHANNELS, + in_channels=INPUT_CHANNELS, + kernel_size=KERNEL_SIZE, + bias=False) + quant_mod = QuantConv3d( + out_channels=OUTPUT_CHANNELS, + in_channels=INPUT_CHANNELS, + kernel_size=KERNEL_SIZE, + weight_quant_delay_steps=1, + bias=False) + quant_mod.load_state_dict(float_mod.state_dict()) + inp = torch.randn(1, INPUT_CHANNELS, 20, 20, 20) + out_float = float_mod(inp) + out_quant = quant_mod(inp) + assert out_float.isclose(out_quant).all().item() + + def test_internally_scaled_int_bias(self): + mod = QuantConv3d( + out_channels=OUTPUT_CHANNELS, + in_channels=INPUT_CHANNELS, + kernel_size=KERNEL_SIZE, + weight_quant_delay_steps=1, + bias=True, + bias_quant=Int8BiasPerTensorFloatInternalScaling) + inp = torch.randn(1, INPUT_CHANNELS, 20, 20, 20) + mod(inp) + + def test_internally_scaled_int_bias_after_bn_merge(self): + mod = QuantConv3d( + out_channels=OUTPUT_CHANNELS, + in_channels=INPUT_CHANNELS, + kernel_size=KERNEL_SIZE, + weight_quant_delay_steps=1, + bias=False, + bias_quant=Int8BiasPerTensorFloatInternalScaling) + bn = BatchNorm3d(OUTPUT_CHANNELS) + merge_bn(mod, bn) + inp = torch.randn(1, INPUT_CHANNELS, 20, 20, 20) + mod(inp) diff --git a/tests/brevitas/nn/test_convtranspose.py b/tests/brevitas/nn/test_convtranspose.py new file mode 100644 index 000000000..deb585db0 --- /dev/null +++ b/tests/brevitas/nn/test_convtranspose.py @@ -0,0 +1,69 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import torch +import torch.nn as nn + +from brevitas.nn import QuantConvTranspose1d +from brevitas.nn import QuantConvTranspose2d +from brevitas.nn import QuantConvTranspose3d + + +def test_quantconvtranspose1d(): + in_channels = 16 + out_channels = 4 + kernel_size = 3 + + input = torch.ones(10, in_channels, 50) + + normal = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride=2) + normal_output = normal(input) + + quant = QuantConvTranspose1d(in_channels, out_channels, kernel_size, stride=2) + + # re-using weight and bias so the layers should give the same results + quant.weight = normal.weight + quant.bias = normal.bias + quant_output = quant(input) + + assert torch.isclose(normal_output, quant_output, atol=0.01).all().item() + + +def test_quantconvtranspose2d(): + in_channels = 16 + out_channels = 4 + kernel_size = 3 + + input = torch.ones(10, in_channels, 50, 100) + + normal = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=2) + normal_output = normal(input) + + quant = QuantConvTranspose2d(in_channels, out_channels, kernel_size, stride=2) + + # re-using weight and bias so the layers should give the same results + quant.weight = normal.weight + quant.bias = normal.bias + quant_output = quant(input) + + assert torch.isclose(normal_output, quant_output, atol=0.01).all().item() + + +def test_quantconvtranspose3d(): + in_channels = 16 + out_channels = 4 + kernel_size = 3 + + input = torch.ones(10, in_channels, 10, 50, 100) + + normal = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=2) + normal_output = normal(input) + + quant = QuantConvTranspose3d(in_channels, out_channels, kernel_size, stride=2) + + # re-using weight and bias so the layers should give the same results + quant.weight = normal.weight + quant.bias = normal.bias + quant_output = quant(input) + + assert torch.isclose(normal_output, quant_output, atol=0.01).all().item() diff --git a/tests/brevitas/nn/test_wbiol.py b/tests/brevitas/nn/test_wbiol.py index 3577d25a1..58b9a86ca 100644 --- a/tests/brevitas/nn/test_wbiol.py +++ b/tests/brevitas/nn/test_wbiol.py @@ -7,8 +7,10 @@ from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d +from brevitas.nn import QuantConv3d from brevitas.nn import QuantConvTranspose1d from brevitas.nn import QuantConvTranspose2d +from brevitas.nn import QuantConvTranspose3d from brevitas.nn import QuantLinear from brevitas.nn import QuantScaleBias from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL @@ -25,7 +27,13 @@ IN_CH = 5 KERNEL_SIZE = 3 -QUANT_CONV_VARIANTS = [QuantConv1d, QuantConv2d, QuantConvTranspose1d, QuantConvTranspose2d] +QUANT_CONV_VARIANTS = [ + QuantConv1d, + QuantConv2d, + QuantConv3d, + QuantConvTranspose1d, + QuantConvTranspose2d, + QuantConvTranspose3d,] @pytest_cases.fixture() diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index a7f87cbef..4c148e96b 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -14,8 +14,10 @@ from brevitas.export import export_qonnx from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d +from brevitas.nn import QuantConv3d from brevitas.nn import QuantConvTranspose1d from brevitas.nn import QuantConvTranspose2d +from brevitas.nn import QuantConvTranspose3d from brevitas.nn import QuantLinear from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint @@ -69,7 +71,13 @@ class A2QWeightQuantizerForTests(Int8AccumulatorAwareWeightQuant): 'symmetric_per_channel_fixed_point': (Int8WeightPerChannelFixedPoint, Int8ActPerTensorFixedPoint)} QUANT_WBIOL_IMPL = [ - QuantLinear, QuantConv1d, QuantConv2d, QuantConvTranspose1d, QuantConvTranspose2d] + QuantLinear, + QuantConv1d, + QuantConv2d, + QuantConv3d, + QuantConvTranspose1d, + QuantConvTranspose2d, + QuantConvTranspose3d,] def compute_ort(export_name, np_input): diff --git a/tests/brevitas_ort/test_quant_module.py b/tests/brevitas_ort/test_quant_module.py index 0b1277686..fce766d2d 100644 --- a/tests/brevitas_ort/test_quant_module.py +++ b/tests/brevitas_ort/test_quant_module.py @@ -30,7 +30,8 @@ def test_ort_wbiol(model, export_type, current_cases): o_bit_width = case_id.split('-')[-5] i_bit_width = case_id.split('-')[-3] - if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d') and export_type == 'qop': + if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d', + 'QuantConvTranspose3d') and export_type == 'qop': pytest.skip('Export of ConvTranspose is not supported for QOperation') if 'per_channel' in quantizer and 'asymmetric' in quantizer: pytest.skip('Per-channel zero-point is not well supported in ORT.') @@ -44,8 +45,12 @@ def test_ort_wbiol(model, export_type, current_cases): in_size = (1, IN_CH) elif impl in ('QuantConv1d', 'QuantConvTranspose1d'): in_size = (1, IN_CH, FEATURES) - else: + elif impl in ('QuantConv2d', 'QuantConvTranspose2d'): in_size = (1, IN_CH, FEATURES, FEATURES) + elif impl in ('QuantConv3d', 'QuantConvTranspose3d'): + in_size = (1, IN_CH, FEATURES, FEATURES, FEATURES) + else: + raise RuntimeError("Unsupported operation") inp = gen_linspaced_data(reduce(mul, in_size), -1, 1).reshape(in_size)