From 5979acb4dc24622ddc42546357a22f8af9eb49ea Mon Sep 17 00:00:00 2001
From: costigt-dev <156176839+costigt-dev@users.noreply.github.com>
Date: Fri, 19 Jan 2024 14:16:34 +0000
Subject: [PATCH 01/37] first pass attempt at implementing QuantConv3d

---
 src/brevitas/nn/quant_conv.py | 108 +++++++++++++++++++++++++++++++++-
 1 file changed, 105 insertions(+), 3 deletions(-)

diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py
index 846d4f290..191354fac 100644
--- a/src/brevitas/nn/quant_conv.py
+++ b/src/brevitas/nn/quant_conv.py
@@ -8,6 +8,7 @@ from torch import Tensor
 from torch.nn import Conv1d
 from torch.nn import Conv2d
+from torch.nn import Conv3d
 from torch.nn import functional as F
 
 from brevitas.function.ops import max_int
@@ -20,7 +21,7 @@ from .quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
 from .quant_layer import WeightQuantType
 
-__all__ = ['QuantConv1d', 'QuantConv2d']
+__all__ = ['QuantConv1d', 'QuantConv2d', 'QuantConv3d']
 
 
 class QuantConv1d(QuantWBIOL, Conv1d):
@@ -175,8 +176,9 @@ def per_elem_ops(self):
     @property
     def output_channel_dim(self):
         if self.transposed:
-            raise RuntimeError("Transposed kernels not supported")
-        return 0
+            return 1
+        else:
+            return 0
 
     @property
     def channelwise_separable(self) -> bool:
@@ -211,3 +213,103 @@ def max_acc_bit_width(self, input_bit_width: Tensor, weight_bit_width: Tensor):
         max_uint_output = max_uint_input * max_kernel_val * kernel_size * group_size
         max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
         return max_output_bit_width
+
+class QuantConv3d(QuantWBIOL, Conv3d):
+    def __init__(
+            self,
+            in_channels: int,
+            out_channels: int,
+            kernel_size: Union[int, Tuple[int, int, int]],
+            stride: Union[int, Tuple[int, int, int]] = 1,
+            padding: Union[int, Tuple[int, int, int]] = 0,
+            dilation: Union[int, Tuple[int, int, int]] = 1,
+            groups: int = 1,
+            padding_mode: str = 'zeros',
+            bias: bool = True,
+            weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat,
+            bias_quant: Optional[BiasQuantType] = None,
+            input_quant: Optional[ActQuantType] = None,
+            output_quant: Optional[ActQuantType] = None,
+            return_quant_tensor: bool = False,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
+            **kwargs) -> None:
+        # avoid an init error in the super class by setting padding to 0
+        if padding_mode == 'zeros' and padding == 'same' and stride > 1:
+            padding = 0
+            is_same_padded_strided = True
+        else:
+            is_same_padded_strided = False
+        Conv3d.__init__(
+            self,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            padding_mode=padding_mode,
+            bias=bias,
+            device=device,
+            dtype=dtype)
+        QuantWBIOL.__init__(
+            self,
+            weight_quant=weight_quant,
+            bias_quant=bias_quant,
+            input_quant=input_quant,
+            output_quant=output_quant,
+            return_quant_tensor=return_quant_tensor,
+            **kwargs)
+        self.is_same_padded_strided = is_same_padded_strided
+
+    @property
+    def per_elem_ops(self):
+        flat_kernel_size = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]
+        return 2 * flat_kernel_size * (self.in_channels // self.groups)
+
+    @property
+    def output_channel_dim(self):
+        if self.transposed:
+            return 1
+        else:
+            return 0
+
+    @property
+    def channelwise_separable(self) -> bool:
+        # according to https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html
+        # if groups == in_channels that means each channel is convolved with its own set of filters
+        return self.groups == self.in_channels
+
+    def conv3d_same_zeros_pad_stride(self, x, weight, bias):
+        id, ih, iw = x.size()[-3:]
+        kd, kh, kw = weight.size()[-3:]
+        sd, sh, sw = self.stride
+        od, oh, ow = math.ceil(id / sd), math.ceil(ih / sh), math.ceil(iw / sw)
+        pad_d = max((od - 1) * self.stride[0] + (kd - 1) * self.dilation[0] + 1 - id, 0)
+        pad_h = max((oh - 1) * self.stride[1] + (kh - 1) * self.dilation[1] + 1 - ih, 0)
+        pad_w = max((ow - 1) * self.stride[2] + (kw - 1) * self.dilation[2] + 1 - iw, 0)
+        if pad_d > 0 or pad_h > 0 or pad_w > 0:
+            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2, pad_d // 2, pad_d - pad_d // 2])
+        out = F.conv3d(x, weight, bias, self.stride, 0, self.dilation, self.groups)
+        return out
+
+    def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor,QuantTensor]:
+        # calls QuantWBIOL.forward_impl and eventually inner_forward_impl below
+        return self.forward_impl(input)
+
+    # override of QuantWBIOL method, called by QuantWBIOL.forward_impl
+    def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
+        if self.is_same_padded_strided:
+            return self.conv3d_same_zeros_pad_stride(x, quant_weight, quant_bias)
+        else:
+            return self._conv_forward(x, quant_weight, quant_bias)
+
+    def max_acc_bit_width(self, input_bit_width: Tensor, weight_bit_width: Tensor):
+        max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
+        max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
+        group_size = self.in_channels // self.groups
+        kernel_size = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]
+        max_uint_output = max_uint_input * max_kernel_val * kernel_size * group_size
+        max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
+        return max_output_bit_width
\ No newline at end of file

From 9ecea98777d257d9b9c1eaab41dad2d55c28568c Mon Sep 17 00:00:00 2001
From: costigt-dev <156176839+costigt-dev@users.noreply.github.com>
Date: Fri, 19 Jan 2024 14:18:14 +0000
Subject: [PATCH 02/37] placeholder implementation for QuantConvTranspose3d

---
 src/brevitas/nn/quant_convtranspose.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/brevitas/nn/quant_convtranspose.py b/src/brevitas/nn/quant_convtranspose.py
index c5dbd52b9..0eae0c9cc 100644
--- a/src/brevitas/nn/quant_convtranspose.py
+++ b/src/brevitas/nn/quant_convtranspose.py
@@ -8,8 +8,10 @@ from torch import Tensor
 from torch.nn import ConvTranspose1d
 from torch.nn import ConvTranspose2d
+from torch.nn import ConvTranspose3d
 from torch.nn.functional import conv_transpose1d
 from torch.nn.functional import conv_transpose2d
+from torch.nn.functional import conv_transpose3d
 
 from brevitas import torch_version
 from brevitas.function.ops import max_int
@@ -22,7 +24,7 @@ from .quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
 from .quant_layer import WeightQuantType
 
-__all__ = ['QuantConvTranspose1d', 'QuantConvTranspose2d']
+__all__ = ['QuantConvTranspose1d', 'QuantConvTranspose2d', 'QuantConvTranspose3d']
 
 
 class QuantConvTranspose1d(QuantWBIOL, ConvTranspose1d):
@@ -218,3 +220,6 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width):
     max_uint_output = max_uint_input * max_kernel_val * overlapping_sums * group_size
     max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
     return max_output_bit_width
+
+class QuantConvTranspose3d(QuantWBIOL, ConvTranspose3d):
+    pass
\ No newline at end of file
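A minimal usage sketch of the layer added above; the shapes are illustrative, and the defaults shown are the ones visible in the patch (Int8WeightPerTensorFloat weights, return_quant_tensor=False):

    import torch
    from brevitas.nn import QuantConv3d

    # Conv3d consumes 5D inputs: (batch, channels, depth, height, width)
    x = torch.randn(1, 5, 20, 20, 20)
    conv = QuantConv3d(in_channels=5, out_channels=10, kernel_size=(3, 3, 3), bias=False)
    y = conv(x)  # plain Tensor of shape (1, 10, 18, 18, 18)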
From 091d70d997b846fbf22cbf573a5a2c626caee7e5 Mon Sep 17 00:00:00 2001
From: costigt-dev <156176839+costigt-dev@users.noreply.github.com>
Date: Fri, 19 Jan 2024 14:57:24 +0000
Subject: [PATCH 03/37] first implementation of QuantConvTranspose3d

---
 src/brevitas/nn/quant_convtranspose.py | 96 +++++++++++++++++++++++++-
 1 file changed, 95 insertions(+), 1 deletion(-)

diff --git a/src/brevitas/nn/quant_convtranspose.py b/src/brevitas/nn/quant_convtranspose.py
index 0eae0c9cc..ed9d778d9 100644
--- a/src/brevitas/nn/quant_convtranspose.py
+++ b/src/brevitas/nn/quant_convtranspose.py
@@ -222,4 +222,98 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width):
     return max_output_bit_width
 
 class QuantConvTranspose3d(QuantWBIOL, ConvTranspose3d):
-    pass
\ No newline at end of file
+    def __init__(
+            self,
+            in_channels: int,
+            out_channels: int,
+            kernel_size: Union[int, Tuple[int]],
+            stride: Union[int, Tuple[int]] = 1,
+            padding: Union[int, Tuple[int]] = 0,
+            output_padding: Union[int, Tuple[int]] = 0,
+            dilation: Union[int, Tuple[int]] = 1,
+            groups: int = 1,
+            padding_mode: str = 'zeros',
+            bias: bool = True,
+            weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat,
+            bias_quant: Optional[BiasQuantType] = None,
+            input_quant: Optional[ActQuantType] = None,
+            output_quant: Optional[ActQuantType] = None,
+            return_quant_tensor: bool = False,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
+            **kwargs) -> None:
+        ConvTranspose3d.__init__(
+            self,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            output_padding=output_padding,
+            dilation=dilation,
+            groups=groups,
+            padding_mode=padding_mode,
+            bias=bias,
+            device=device,
+            dtype=dtype)
+        QuantWBIOL.__init__(
+            self,
+            weight_quant=weight_quant,
+            bias_quant=bias_quant,
+            input_quant=input_quant,
+            output_quant=output_quant,
+            return_quant_tensor=return_quant_tensor,
+            **kwargs)
+        self._output_size = None
+
+    @property
+    def per_elem_ops(self):
+        raise NotImplementedError
+
+    @property
+    def output_channel_dim(self) -> int:
+        return 1
+
+    @property
+    def channelwise_separable(self) -> bool:
+        return self.groups == self.out_channels
+
+    def forward(self,
+                input: Union[Tensor, QuantTensor],
+                output_size=None) -> Union[Tensor, QuantTensor]:
+        self._output_size = output_size  # cache the value temporarily
+        return self.forward_impl(input)
+
+    def compute_output_padding(self, inp, output_size):
+        if torch_version >= version.parse('1.12'):
+            return self._output_padding(
+                inp, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims=3)
+        else:
+            return self._output_padding(
+                inp, output_size, self.stride, self.padding, self.kernel_size)
+
+    def conv_transpose3d_zeros_pad(
+            self, x: Tensor, weight: Tensor, bias: Optional[Tensor], output_padding):
+        out = conv_transpose3d(
+            x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
+        return out
+
+    def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
+        if self.padding_mode == 'zeros':
+            output_padding = self.compute_output_padding(x, self._output_size)
+            self._output_size = None  # set it back to None after consuming it
+            out = self.conv_transpose3d_zeros_pad(x, quant_weight, quant_bias, output_padding)
+            return out
+        else:
+            raise NotImplementedError(f"Padding mode {self.padding_mode} not supported.")
+
+    def max_acc_bit_width(self, input_bit_width, weight_bit_width):
+        max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
+        max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
+        group_size = self.out_channels // self.groups
+        overlapping_sums = max(round(self.kernel_size[0] / self.stride[0]), 1)
+        overlapping_sums *= max(round(self.kernel_size[1] / self.stride[1]), 1)
+        overlapping_sums *= max(round(self.kernel_size[2] / self.stride[2]), 1)
+        max_uint_output = max_uint_input * max_kernel_val * overlapping_sums * group_size
+        max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
+        return max_output_bit_width
\ No newline at end of file
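The output_size argument mirrors torch's ConvTranspose3d call signature: forward() caches it and compute_output_padding() consumes it to disambiguate the output shape of a strided transposed conv. A small sketch of how it is meant to be driven, with illustrative shapes:

    import torch
    from brevitas.nn import QuantConvTranspose3d

    deconv = QuantConvTranspose3d(in_channels=5, out_channels=10, kernel_size=3, stride=2)
    x = torch.randn(1, 5, 8, 8, 8)
    # stride 2 makes both 17 and 18 valid sizes per spatial dim; output_size selects one,
    # resolved internally to output_padding=1 in this case
    y = deconv(x, output_size=(18, 18, 18))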
From 44dc026d02c7559267783cb7411b2ac32fcb1293 Mon Sep 17 00:00:00 2001
From: costigt-dev <156176839+costigt-dev@users.noreply.github.com>
Date: Mon, 22 Jan 2024 15:37:51 +0000
Subject: [PATCH 04/37] added new conv3d classes to the __init__.py

---
 src/brevitas/nn/__init__.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/brevitas/nn/__init__.py b/src/brevitas/nn/__init__.py
index 7138da2bd..0db183c5a 100644
--- a/src/brevitas/nn/__init__.py
+++ b/src/brevitas/nn/__init__.py
@@ -15,8 +15,10 @@ from .quant_bn import BatchNorm2dToQuantScaleBias
 from .quant_conv import QuantConv1d
 from .quant_conv import QuantConv2d
+from .quant_conv import QuantConv3d
 from .quant_convtranspose import QuantConvTranspose1d
 from .quant_convtranspose import QuantConvTranspose2d
+from .quant_convtranspose import QuantConvTranspose3d
 from .quant_dropout import QuantDropout
 from .quant_eltwise import QuantCat
 from .quant_eltwise import QuantEltwiseAdd

From 9c5fff5cf8deb69b661c3d50ee1c1e4943aad1e3 Mon Sep 17 00:00:00 2001
From: costigt-dev <156176839+costigt-dev@users.noreply.github.com>
Date: Mon, 22 Jan 2024 15:47:38 +0000
Subject: [PATCH 05/37] added space to QuantConv3d to be close to other classes
 in file

---
 src/brevitas/nn/quant_conv.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py
index 191354fac..9b9b28234 100644
--- a/src/brevitas/nn/quant_conv.py
+++ b/src/brevitas/nn/quant_conv.py
@@ -215,6 +215,7 @@ def max_acc_bit_width(self, input_bit_width: Tensor, weight_bit_width: Tensor):
         return max_output_bit_width
 
 class QuantConv3d(QuantWBIOL, Conv3d):
+
     def __init__(

From 6121a10d32ed797213a8903d2f7064b113fb682a Mon Sep 17 00:00:00 2001
From: costigt-dev <156176839+costigt-dev@users.noreply.github.com>
Date: Mon, 22 Jan 2024 15:48:24 +0000
Subject: [PATCH 06/37] adapted conv2d to conv3d

---
 tests/brevitas/nn/test_conv3d.py | 101 +++++++++++++++++++++++++++++++
 1 file changed, 101 insertions(+)
 create mode 100644 tests/brevitas/nn/test_conv3d.py

diff --git a/tests/brevitas/nn/test_conv3d.py b/tests/brevitas/nn/test_conv3d.py
new file mode 100644
index 000000000..3dc461177
--- /dev/null
+++ b/tests/brevitas/nn/test_conv3d.py
@@ -0,0 +1,101 @@
+# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause + +import torch +from torch.nn import BatchNorm3d +from torch.nn import Conv3d + +from brevitas.inject.defaults import Int8BiasPerTensorFloatInternalScaling +from brevitas.nn import QuantConv3d +from brevitas.nn.utils import merge_bn + +OUTPUT_CHANNELS = 10 +INPUT_CHANNELS = 5 +KERNEL_SIZE = (3,3,3) +WEIGHT_BIT_WIDTH = 5 + +class TestQuantConv3d: + + def test_module_init(self): + mod = QuantConv3d(out_channels=OUTPUT_CHANNELS, + in_channels=INPUT_CHANNELS, + kernel_size=KERNEL_SIZE, + bias=False) + + def test_fp_quant_module(self): + float_mod = Conv3d( + out_channels=OUTPUT_CHANNELS, + in_channels=INPUT_CHANNELS, + kernel_size=KERNEL_SIZE, + bias=False) + quant_mod = QuantConv3d( + out_channels=OUTPUT_CHANNELS, + in_channels=INPUT_CHANNELS, + kernel_size=KERNEL_SIZE, + weight_quant_type='FP', + bias=False) + quant_mod.load_state_dict(float_mod.state_dict()) + inp = torch.randn(1, INPUT_CHANNELS, 20, 20, 20) + out_float = float_mod(inp) + out_quant = quant_mod(inp) + assert out_float.isclose(out_quant).all().item() + + def test_none_weight_quant_module(self): + float_mod = Conv3d( + out_channels=OUTPUT_CHANNELS, + in_channels=INPUT_CHANNELS, + kernel_size=KERNEL_SIZE, + bias=False) + quant_mod = QuantConv3d( + out_channels=OUTPUT_CHANNELS, + in_channels=INPUT_CHANNELS, + kernel_size=KERNEL_SIZE, + weight_quant=None, + bias=False) + quant_mod.load_state_dict(float_mod.state_dict()) + inp = torch.randn(1, INPUT_CHANNELS, 20, 20, 20) + out_float = float_mod(inp) + out_quant = quant_mod(inp) + assert out_float.isclose(out_quant).all().item() + + def test_delayed_quant_module(self): + float_mod = Conv3d( + out_channels=OUTPUT_CHANNELS, + in_channels=INPUT_CHANNELS, + kernel_size=KERNEL_SIZE, + bias=False) + quant_mod = QuantConv3d( + out_channels=OUTPUT_CHANNELS, + in_channels=INPUT_CHANNELS, + kernel_size=KERNEL_SIZE, + weight_quant_delay_steps=1, + bias=False) + quant_mod.load_state_dict(float_mod.state_dict()) + inp = torch.randn(1, INPUT_CHANNELS, 20, 20, 20) + out_float = float_mod(inp) + out_quant = quant_mod(inp) + assert out_float.isclose(out_quant).all().item() + + def test_internally_scaled_int_bias(self): + mod = QuantConv3d( + out_channels=OUTPUT_CHANNELS, + in_channels=INPUT_CHANNELS, + kernel_size=KERNEL_SIZE, + weight_quant_delay_steps=1, + bias=True, + bias_quant=Int8BiasPerTensorFloatInternalScaling) + inp = torch.randn(1, INPUT_CHANNELS, 20, 20, 20) + mod(inp) + + def test_internally_scaled_int_bias_after_bn_merge(self): + mod = QuantConv3d( + out_channels=OUTPUT_CHANNELS, + in_channels=INPUT_CHANNELS, + kernel_size=KERNEL_SIZE, + weight_quant_delay_steps=1, + bias=False, + bias_quant=Int8BiasPerTensorFloatInternalScaling) + bn = BatchNorm3d(OUTPUT_CHANNELS) + merge_bn(mod, bn) + inp = torch.randn(1, INPUT_CHANNELS, 20, 20, 20) + mod(inp) \ No newline at end of file From c568895d170c6a485814955e49e5b125a5ec7a7d Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Tue, 23 Jan 2024 09:16:44 +0000 Subject: [PATCH 07/37] removed is_same_padded_strided and its accompanying function as it is specifically for a mode that pytorch doesn't support --- src/brevitas/nn/quant_conv.py | 72 ++--------------------------------- 1 file changed, 3 insertions(+), 69 deletions(-) diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py index 9b9b28234..3d0b9e904 100644 --- a/src/brevitas/nn/quant_conv.py +++ b/src/brevitas/nn/quant_conv.py @@ -45,12 +45,6 @@ def __init__( device: 
Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, **kwargs) -> None: - # avoid an init error in the super class by setting padding to 0 - if padding_mode == 'zeros' and padding == 'same' and stride > 1: - padding = 0 - is_same_padded_strided = True - else: - is_same_padded_strided = False Conv1d.__init__( self, in_channels=in_channels, @@ -72,7 +66,6 @@ def __init__( output_quant=output_quant, return_quant_tensor=return_quant_tensor, **kwargs) - self.is_same_padded_strided = is_same_padded_strided @property def per_elem_ops(self): @@ -89,25 +82,11 @@ def output_channel_dim(self): def channelwise_separable(self) -> bool: return self.groups == self.in_channels - def conv1d_same_zeros_pad_stride(self, x, weight, bias): - ih = x.size()[-1] - kh = weight.size()[-1] - sh = self.stride[0] - oh = math.ceil(ih / sh) - pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) - if pad_h > 0: - x = F.pad(x, [pad_h // 2, pad_h - pad_h // 2]) - out = F.conv1d(x, weight, bias, self.stride, 0, self.dilation, self.groups) - return out - def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: return self.forward_impl(input) def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]): - if self.is_same_padded_strided: - return self.conv1d_same_zeros_pad_stride(x, quant_weight, quant_bias) - else: - return self._conv_forward(x, quant_weight, quant_bias) + return self._conv_forward(x, quant_weight, quant_bias) def max_acc_bit_width(self, input_bit_width, weight_bit_width): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) @@ -139,12 +118,6 @@ def __init__( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, **kwargs) -> None: - # avoid an init error in the super class by setting padding to 0 - if padding_mode == 'zeros' and padding == 'same' and stride > 1: - padding = 0 - is_same_padded_strided = True - else: - is_same_padded_strided = False Conv2d.__init__( self, in_channels=in_channels, @@ -166,7 +139,6 @@ def __init__( output_quant=output_quant, return_quant_tensor=return_quant_tensor, **kwargs) - self.is_same_padded_strided = is_same_padded_strided @property def per_elem_ops(self): @@ -184,26 +156,11 @@ def output_channel_dim(self): def channelwise_separable(self) -> bool: return self.groups == self.in_channels - def conv2d_same_zeros_pad_stride(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]): - ih, iw = x.size()[-2:] - kh, kw = weight.size()[-2:] - sh, sw = self.stride - oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) - pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) - pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) - if pad_h > 0 or pad_w > 0: - x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) - out = F.conv2d(x, weight, bias, self.stride, 0, self.dilation, self.groups) - return out - def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: return self.forward_impl(input) def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]): - if self.is_same_padded_strided: - return self.conv2d_same_zeros_pad_stride(x, quant_weight, quant_bias) - else: - return self._conv_forward(x, quant_weight, quant_bias) + return self._conv_forward(x, quant_weight, quant_bias) def max_acc_bit_width(self, input_bit_width: Tensor, weight_bit_width: Tensor): max_uint_input = 
max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
@@ -235,12 +192,6 @@ def __init__(
             device: Optional[torch.device] = None,
             dtype: Optional[torch.dtype] = None,
             **kwargs) -> None:
-        # avoid an init error in the super class by setting padding to 0
-        if padding_mode == 'zeros' and padding == 'same' and stride > 1:
-            padding = 0
-            is_same_padded_strided = True
-        else:
-            is_same_padded_strided = False
         Conv3d.__init__(
             self,
             in_channels=in_channels,
@@ -262,7 +213,6 @@ def __init__(
             output_quant=output_quant,
             return_quant_tensor=return_quant_tensor,
             **kwargs)
-        self.is_same_padded_strided = is_same_padded_strided
 
     @property
     def per_elem_ops(self):
@@ -282,29 +232,13 @@ def channelwise_separable(self) -> bool:
         # if groups == in_channels that means each channel is convolved with its own set of filters
         return self.groups == self.in_channels
 
-    def conv3d_same_zeros_pad_stride(self, x, weight, bias):
-        id, ih, iw = x.size()[-3:]
-        kd, kh, kw = weight.size()[-3:]
-        sd, sh, sw = self.stride
-        od, oh, ow = math.ceil(id / sd), math.ceil(ih / sh), math.ceil(iw / sw)
-        pad_d = max((od - 1) * self.stride[0] + (kd - 1) * self.dilation[0] + 1 - id, 0)
-        pad_h = max((oh - 1) * self.stride[1] + (kh - 1) * self.dilation[1] + 1 - ih, 0)
-        pad_w = max((ow - 1) * self.stride[2] + (kw - 1) * self.dilation[2] + 1 - iw, 0)
-        if pad_d > 0 or pad_h > 0 or pad_w > 0:
-            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2, pad_d // 2, pad_d - pad_d // 2])
-        out = F.conv3d(x, weight, bias, self.stride, 0, self.dilation, self.groups)
-        return out
-
     def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor,QuantTensor]:
         # calls QuantWBIOL.forward_impl and eventually inner_forward_impl below
         return self.forward_impl(input)
 
     # override of QuantWBIOL method, called by QuantWBIOL.forward_impl
     def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
-        if self.is_same_padded_strided:
-            return self.conv3d_same_zeros_pad_stride(x, quant_weight, quant_bias)
-        else:
-            return self._conv_forward(x, quant_weight, quant_bias)
+        return self._conv_forward(x, quant_weight, quant_bias)
 
     def max_acc_bit_width(self, input_bit_width: Tensor, weight_bit_width: Tensor):
         max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)

From cf86c8b1f5de579c762806944511fce87b066309 Mon Sep 17 00:00:00 2001
From: costigt-dev <156176839+costigt-dev@users.noreply.github.com>
Date: Wed, 24 Jan 2024 14:16:34 +0000
Subject: [PATCH 08/37] formatting fixes

---
 src/brevitas/nn/quant_conv.py          |  7 ++++---
 src/brevitas/nn/quant_convtranspose.py |  8 +++++---
 tests/brevitas/nn/test_conv3d.py       | 18 ++++++++++--------
 3 files changed, 19 insertions(+), 14 deletions(-)

diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py
index 3d0b9e904..2ae764f5c 100644
--- a/src/brevitas/nn/quant_conv.py
+++ b/src/brevitas/nn/quant_conv.py
@@ -171,8 +171,9 @@ def max_acc_bit_width(self, input_bit_width: Tensor, weight_bit_width: Tensor):
     max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
     return max_output_bit_width
 
+
 class QuantConv3d(QuantWBIOL, Conv3d):
-
+
     def __init__(
             self,
             in_channels: int,
@@ -232,7 +233,7 @@ def channelwise_separable(self) -> bool:
         # if groups == in_channels that means each channel is convolved with its own set of filters
         return self.groups == self.in_channels
 
-    def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor,QuantTensor]:
+    def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         # calls 
QuantWBIOL.forward_impl and eventually inner_forward_impl below return self.forward_impl(input) @@ -247,4 +248,4 @@ def max_acc_bit_width(self, input_bit_width: Tensor, weight_bit_width: Tensor): kernel_size = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2] max_uint_output = max_uint_input * max_kernel_val * kernel_size * group_size max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) - return max_output_bit_width \ No newline at end of file + return max_output_bit_width diff --git a/src/brevitas/nn/quant_convtranspose.py b/src/brevitas/nn/quant_convtranspose.py index ed9d778d9..caf91d002 100644 --- a/src/brevitas/nn/quant_convtranspose.py +++ b/src/brevitas/nn/quant_convtranspose.py @@ -221,7 +221,9 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width): max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) return max_output_bit_width + class QuantConvTranspose3d(QuantWBIOL, ConvTranspose3d): + def __init__( self, in_channels: int, @@ -283,7 +285,7 @@ def forward(self, output_size=None) -> Union[Tensor, QuantTensor]: self._output_size = output_size # cache the value temporarily return self.forward_impl(input) - + def compute_output_padding(self, inp, output_size): if torch_version >= version.parse('1.12'): return self._output_padding( @@ -291,7 +293,7 @@ def compute_output_padding(self, inp, output_size): else: return self._output_padding( inp, output_size, self.stride, self.padding, self.kernel_size) - + def conv_transpose3d_zeros_pad( self, x: Tensor, weight: Tensor, bias: Optional[Tensor], output_padding): out = conv_transpose3d( @@ -316,4 +318,4 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width): overlapping_sums *= max(round(self.kernel_size[2] / self.stride[2]), 1) max_uint_output = max_uint_input * max_kernel_val * overlapping_sums * group_size max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) - return max_output_bit_width \ No newline at end of file + return max_output_bit_width diff --git a/tests/brevitas/nn/test_conv3d.py b/tests/brevitas/nn/test_conv3d.py index 3dc461177..fdd35524f 100644 --- a/tests/brevitas/nn/test_conv3d.py +++ b/tests/brevitas/nn/test_conv3d.py @@ -11,17 +11,19 @@ OUTPUT_CHANNELS = 10 INPUT_CHANNELS = 5 -KERNEL_SIZE = (3,3,3) +KERNEL_SIZE = (3, 3, 3) WEIGHT_BIT_WIDTH = 5 + class TestQuantConv3d: - + def test_module_init(self): - mod = QuantConv3d(out_channels=OUTPUT_CHANNELS, - in_channels=INPUT_CHANNELS, - kernel_size=KERNEL_SIZE, - bias=False) - + mod = QuantConv3d( + out_channels=OUTPUT_CHANNELS, + in_channels=INPUT_CHANNELS, + kernel_size=KERNEL_SIZE, + bias=False) + def test_fp_quant_module(self): float_mod = Conv3d( out_channels=OUTPUT_CHANNELS, @@ -98,4 +100,4 @@ def test_internally_scaled_int_bias_after_bn_merge(self): bn = BatchNorm3d(OUTPUT_CHANNELS) merge_bn(mod, bn) inp = torch.randn(1, INPUT_CHANNELS, 20, 20, 20) - mod(inp) \ No newline at end of file + mod(inp) From c62e07bb697f33fca86aedfcfa5403e02ef01966 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Mon, 29 Jan 2024 14:07:45 +0000 Subject: [PATCH 09/37] Revert "removed is_same_padded_strided and its accompanying function as it is specifically for a mode that pytorch doesn't support" This reverts commit c568895d170c6a485814955e49e5b125a5ec7a7d. 
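For context: stock PyTorch rejects padding='same' combined with stride > 1 (the Conv modules raise a ValueError for that combination), which is why these helpers emulate TF-style 'same' output sizing by hand. The restored math, worked per spatial dim with illustrative numbers:

    # out = ceil(in / stride); pad just enough that the strided window grid covers the input
    # e.g. in=20, stride=2, kernel=3, dilation=1:
    #   out = ceil(20 / 2) = 10
    #   pad = max((10 - 1) * 2 + (3 - 1) * 1 + 1 - 20, 0) = 1
    #   F.pad split = (pad // 2, pad - pad // 2) = (0, 1), i.e. the odd element of padding goes at the end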
--- src/brevitas/nn/quant_conv.py | 74 +++++++++++++++++++++++++++++++++-- 1 file changed, 70 insertions(+), 4 deletions(-) diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py index 2ae764f5c..ae292bdac 100644 --- a/src/brevitas/nn/quant_conv.py +++ b/src/brevitas/nn/quant_conv.py @@ -45,6 +45,12 @@ def __init__( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, **kwargs) -> None: + # avoid an init error in the super class by setting padding to 0 + if padding_mode == 'zeros' and padding == 'same' and stride > 1: + padding = 0 + is_same_padded_strided = True + else: + is_same_padded_strided = False Conv1d.__init__( self, in_channels=in_channels, @@ -66,6 +72,7 @@ def __init__( output_quant=output_quant, return_quant_tensor=return_quant_tensor, **kwargs) + self.is_same_padded_strided = is_same_padded_strided @property def per_elem_ops(self): @@ -82,11 +89,25 @@ def output_channel_dim(self): def channelwise_separable(self) -> bool: return self.groups == self.in_channels + def conv1d_same_zeros_pad_stride(self, x, weight, bias): + ih = x.size()[-1] + kh = weight.size()[-1] + sh = self.stride[0] + oh = math.ceil(ih / sh) + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + if pad_h > 0: + x = F.pad(x, [pad_h // 2, pad_h - pad_h // 2]) + out = F.conv1d(x, weight, bias, self.stride, 0, self.dilation, self.groups) + return out + def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: return self.forward_impl(input) def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]): - return self._conv_forward(x, quant_weight, quant_bias) + if self.is_same_padded_strided: + return self.conv1d_same_zeros_pad_stride(x, quant_weight, quant_bias) + else: + return self._conv_forward(x, quant_weight, quant_bias) def max_acc_bit_width(self, input_bit_width, weight_bit_width): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) @@ -118,6 +139,12 @@ def __init__( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, **kwargs) -> None: + # avoid an init error in the super class by setting padding to 0 + if padding_mode == 'zeros' and padding == 'same' and stride > 1: + padding = 0 + is_same_padded_strided = True + else: + is_same_padded_strided = False Conv2d.__init__( self, in_channels=in_channels, @@ -139,6 +166,7 @@ def __init__( output_quant=output_quant, return_quant_tensor=return_quant_tensor, **kwargs) + self.is_same_padded_strided = is_same_padded_strided @property def per_elem_ops(self): @@ -156,11 +184,26 @@ def output_channel_dim(self): def channelwise_separable(self) -> bool: return self.groups == self.in_channels + def conv2d_same_zeros_pad_stride(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]): + ih, iw = x.size()[-2:] + kh, kw = weight.size()[-2:] + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + out = F.conv2d(x, weight, bias, self.stride, 0, self.dilation, self.groups) + return out + def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: return self.forward_impl(input) def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]): - return 
self._conv_forward(x, quant_weight, quant_bias)
+        if self.is_same_padded_strided:
+            return self.conv2d_same_zeros_pad_stride(x, quant_weight, quant_bias)
+        else:
+            return self._conv_forward(x, quant_weight, quant_bias)
 
     def max_acc_bit_width(self, input_bit_width: Tensor, weight_bit_width: Tensor):
         max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
@@ -193,6 +236,12 @@ def __init__(
             device: Optional[torch.device] = None,
             dtype: Optional[torch.dtype] = None,
             **kwargs) -> None:
+        # avoid an init error in the super class by setting padding to 0
+        if padding_mode == 'zeros' and padding == 'same' and stride > 1:
+            padding = 0
+            is_same_padded_strided = True
+        else:
+            is_same_padded_strided = False
         Conv3d.__init__(
             self,
             in_channels=in_channels,
@@ -214,6 +263,7 @@ def __init__(
             output_quant=output_quant,
             return_quant_tensor=return_quant_tensor,
             **kwargs)
+        self.is_same_padded_strided = is_same_padded_strided
 
     @property
     def per_elem_ops(self):
@@ -233,13 +283,29 @@ def channelwise_separable(self) -> bool:
         # if groups == in_channels that means each channel is convolved with its own set of filters
         return self.groups == self.in_channels
 
+    def conv3d_same_zeros_pad_stride(self, x, weight, bias):
+        id, ih, iw = x.size()[-3:]
+        kd, kh, kw = weight.size()[-3:]
+        sd, sh, sw = self.stride
+        od, oh, ow = math.ceil(id / sd), math.ceil(ih / sh), math.ceil(iw / sw)
+        pad_d = max((od - 1) * self.stride[0] + (kd - 1) * self.dilation[0] + 1 - id, 0)
+        pad_h = max((oh - 1) * self.stride[1] + (kh - 1) * self.dilation[1] + 1 - ih, 0)
+        pad_w = max((ow - 1) * self.stride[2] + (kw - 1) * self.dilation[2] + 1 - iw, 0)
+        if pad_d > 0 or pad_h > 0 or pad_w > 0:
+            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2, pad_d // 2, pad_d - pad_d // 2])
+        out = F.conv3d(x, weight, bias, self.stride, 0, self.dilation, self.groups)
+        return out
+
-    def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
+    def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor,QuantTensor]:
         # calls QuantWBIOL.forward_impl and eventually inner_forward_impl below
         return self.forward_impl(input)
 
     # override of QuantWBIOL method, called by QuantWBIOL.forward_impl
     def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
-        return self._conv_forward(x, quant_weight, quant_bias)
+        if self.is_same_padded_strided:
+            return self.conv3d_same_zeros_pad_stride(x, quant_weight, quant_bias)
+        else:
+            return self._conv_forward(x, quant_weight, quant_bias)
 
     def max_acc_bit_width(self, input_bit_width: Tensor, weight_bit_width: Tensor):
         max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)

From ff281b814d1f2e5ef3e4d671de411b6d5f9b170a Mon Sep 17 00:00:00 2001
From: costigt-dev <156176839+costigt-dev@users.noreply.github.com>
Date: Mon, 29 Jan 2024 15:14:04 +0000
Subject: [PATCH 10/37] updated references to QuantConv and QuantConvTranspose
 throughout code to include QuantConv3d and QuantConvTranspose3d
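A quick way to exercise the new 3d export plumbing once this patch is applied; export_qonnx is Brevitas' generic ONNX entry point (file name and shapes below are illustrative, and the QOperator handlers added here are reached through the qoperator export manager instead):

    import torch
    from brevitas.export import export_qonnx
    from brevitas.nn import QuantConv3d

    model = QuantConv3d(in_channels=5, out_channels=10, kernel_size=3)
    export_qonnx(model, torch.randn(1, 5, 20, 20, 20), export_path='quant_conv3d.onnx')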
---
 src/brevitas/export/onnx/handler.py           | 32 +++++++++++++++
 .../standard/qoperator/handler/parameter.py   | 10 ++++--
 .../export/onnx/standard/qoperator/manager.py |  2 ++
 .../torch/qoperator/handler/parameter.py      | 11 ++++++-
 .../export/torch/qoperator/manager.py         |  2 ++
 src/brevitas/graph/fixed_point.py             | 11 +++++--
 src/brevitas/graph/gpfq.py                    |  8 +++--
 src/brevitas/graph/gptq.py                    |  8 +++--
 src/brevitas/graph/gpxq.py                    | 11 +++++--
 src/brevitas/graph/per_input.py               |  6 ++--
 src/brevitas/graph/quantize.py                | 26 +++++++++++++++
 src/brevitas/graph/target/flexml.py           | 12 +++++++
 src/brevitas/graph/utils.py                   |  3 +-
 src/brevitas/nn/quant_conv.py                 | 12 +++++--
 tests/brevitas/export/quant_module_fixture.py | 10 +++++-
 tests/brevitas/nn/nn_quantizers_fixture.py    |  6 +++-
 tests/brevitas/nn/test_a2q.py                 |  6 ++++
 tests/brevitas/nn/test_wbiol.py               | 10 +++++-
 tests/brevitas_ort/common.py                  | 10 +++++-
 tests/brevitas_ort/test_quant_module.py       |  3 +-
 20 files changed, 177 insertions(+), 22 deletions(-)

diff --git a/src/brevitas/export/onnx/handler.py b/src/brevitas/export/onnx/handler.py
index f856eaeee..9f657f10c 100644
--- a/src/brevitas/export/onnx/handler.py
+++ b/src/brevitas/export/onnx/handler.py
@@ -88,6 +88,38 @@ def kernel_shape(module):
         return list(module.kernel_size)
 
 
+class Kernel3dApplHandlerMixin(ABC):
+
+    @staticmethod
+    def padding(module):
+        if isinstance(module.padding, int):
+            padding = [module.padding] * 6
+        else:
+            padding = list(module.padding) + list(module.padding)
+        return padding
+
+    @staticmethod
+    def stride(module):
+        if isinstance(module.stride, int):
+            return [module.stride] * 3
+        else:
+            return list(module.stride)
+
+    @staticmethod
+    def dilation(module):
+        if isinstance(module.dilation, int):
+            return [module.dilation] * 3
+        else:
+            return list(module.dilation)
+
+    @staticmethod
+    def kernel_shape(module):
+        if isinstance(module.kernel_size, int):
+            return [module.kernel_size] * 3
+        else:
+            return list(module.kernel_size)
+
+
 class ONNXBaseHandler(BaseHandler, ABC):
 
     def __init__(self):

diff --git a/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py b/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py
index c4c552c65..6540b6f2d 100644
--- a/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py
+++ b/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py
@@ -8,11 +8,13 @@
 from brevitas.export.common import to_0dim_if_scalar
 from brevitas.export.onnx.handler import Kernel1dApplHandlerMixin
 from brevitas.export.onnx.handler import Kernel2dApplHandlerMixin
+from brevitas.export.onnx.handler import Kernel3dApplHandlerMixin
 from brevitas.export.onnx.standard.function import DequantizeLinearFn
 from brevitas.export.onnx.standard.function import IntClipFn
 from brevitas.export.onnx.standard.function import QuantizeLinearFn
 from brevitas.nn import QuantConv1d
 from brevitas.nn import QuantConv2d
+from brevitas.nn import QuantConv3d
 from brevitas.nn import QuantLinear
 from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
 
@@ -56,7 +58,7 @@ def validate(cls, module: QuantWBIOL, requires_quant_bias=True):
             assert module.is_quant_bias_signed
         cls.validate_32b_bit_width(module.quant_bias_bit_width(), le_then=True)
 
-    def prepare_for_export(self, module: Union[QuantConv1d, QuantConv2d]):
+    def prepare_for_export(self, module: Union[QuantConv1d, QuantConv2d, QuantConv3d]):
         self.validate(module)
 
         op_symbolic_kwargs = self.op_symbolic_kwargs(module)
@@ -106,7 +108,7 @@ def output_symbolic_execution(self, out: Tensor):
 class StdQOpONNXQuantConvNdHandler(StdQOpONNXQuantWBIOLHandler, ABC):
 
-    def op_symbolic_kwargs(self, module: Union[QuantConv1d, QuantConv2d]):
+    def op_symbolic_kwargs(self, module: Union[QuantConv1d, QuantConv2d, QuantConv3d]):
         conv_symbolic_kwargs = {
             'input_scale': module.quant_input_scale(),
             'input_zero_point': self.quant_input_zero_point(module),
@@ -131,6 +133,10 @@ def op_symbolic_execution(self, inp: Tensor):
         return out
 
 
+class StdQOpONNXQuantConv3dHandler(StdQOpONNXQuantConvNdHandler, Kernel3dApplHandlerMixin):
+    handled_layer = 
QuantConv3d + + class StdQOpONNXQuantConv2dHandler(StdQOpONNXQuantConvNdHandler, Kernel2dApplHandlerMixin): handled_layer = QuantConv2d diff --git a/src/brevitas/export/onnx/standard/qoperator/manager.py b/src/brevitas/export/onnx/standard/qoperator/manager.py index 4c45df04c..174804407 100644 --- a/src/brevitas/export/onnx/standard/qoperator/manager.py +++ b/src/brevitas/export/onnx/standard/qoperator/manager.py @@ -26,6 +26,7 @@ from .handler.base import StdQOpONNXQuantLayerHandler from .handler.parameter import StdQOpONNXQuantConv1dHandler from .handler.parameter import StdQOpONNXQuantConv2dHandler +from .handler.parameter import StdQOpONNXQuantConv3dHandler from .handler.parameter import StdQOpONNXQuantLinearHandler from .handler.pool import StdQOpONNXQuantMaxPool1d from .handler.pool import StdQOpONNXQuantMaxPool2d @@ -48,6 +49,7 @@ class StdQOpONNXManager(StdONNXBaseManager): handlers = [ StdQOpONNXQuantConv1dHandler, StdQOpONNXQuantConv2dHandler, + StdQOpONNXQuantConv3dHandler, StdQOpONNXQuantLinearHandler, StdQOpONNXQuantReLUHandler, StdQOpONNXQuantHardTanhHandler, diff --git a/src/brevitas/export/torch/qoperator/handler/parameter.py b/src/brevitas/export/torch/qoperator/handler/parameter.py index fa110a84c..802a5a053 100644 --- a/src/brevitas/export/torch/qoperator/handler/parameter.py +++ b/src/brevitas/export/torch/qoperator/handler/parameter.py @@ -10,6 +10,7 @@ from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d +from brevitas.nn import QuantConv3d from brevitas.nn import QuantLinear from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL @@ -93,7 +94,7 @@ def explicit_output_dtype(cls): return True @classmethod - def prepare_qf_kwargs(cls, module: Union[QuantConv1d, QuantConv2d]): + def prepare_qf_kwargs(cls, module: Union[QuantConv1d, QuantConv2d, QuantConv3d]): return { 'bias': cls.prepare_bias(module), 'stride': module.stride, @@ -119,6 +120,14 @@ def prepare_qf(cls, module: QuantConv2d): return torch.nn.quantized.functional.conv2d, cls.prepare_qf_kwargs(module) +class PytorchQuantConv3dHandler(PytorchQuantConvNdHandler): + handled_layer = QuantConv3d + + @classmethod + def prepare_qf(cls, module: QuantConv3d): + return torch.nn.quantized.functional.conv3d, cls.prepare_qf_kwargs(module) + + class PytorchQuantLinearHandler(PytorchQuantWBIOLHandler): handled_layer = QuantLinear diff --git a/src/brevitas/export/torch/qoperator/manager.py b/src/brevitas/export/torch/qoperator/manager.py index aea97957c..7567e6770 100644 --- a/src/brevitas/export/torch/qoperator/manager.py +++ b/src/brevitas/export/torch/qoperator/manager.py @@ -19,6 +19,7 @@ from .handler.act import PytorchQuantReLUHandler from .handler.parameter import PytorchQuantConv1dHandler from .handler.parameter import PytorchQuantConv2dHandler +from .handler.parameter import PytorchQuantConv3dHandler from .handler.parameter import PytorchQuantLinearHandler from .handler.pool import PytorchQuantMaxPool1d from .handler.pool import PytorchQuantMaxPool2d @@ -35,6 +36,7 @@ class TorchQOpManager(BaseManager): PytorchQuantReLUHandler, PytorchQuantConv1dHandler, PytorchQuantConv2dHandler, + PytorchQuantConv3dHandler, PytorchQuantLinearHandler] @classmethod diff --git a/src/brevitas/graph/fixed_point.py b/src/brevitas/graph/fixed_point.py index afbd32e67..fead31ba2 100644 --- a/src/brevitas/graph/fixed_point.py +++ b/src/brevitas/graph/fixed_point.py @@ -33,11 +33,14 @@ class MoveSplitBatchNormBeforeCat(UntilFixedPointGraphTransform): nn.Conv3d, nn.ConvTranspose1d, 
nn.ConvTranspose2d, + nn.ConvTranspose3d, qnn.QuantLinear, qnn.QuantConv1d, qnn.QuantConv2d, + qnn.QuantConv3d, qnn.QuantConvTranspose1d, - qnn.QuantConvTranspose2d) + qnn.QuantConvTranspose2d, + qnn.QuantConvTranspose3d) def __init__(self, before_modules_types=DEFAULT_BEFORE_MODULES_TYPES): super(MoveSplitBatchNormBeforeCat, self).__init__() @@ -93,8 +96,10 @@ class MergeBatchNorm(UntilFixedPointGraphTransform): nn.BatchNorm1d), (qnn.BatchNorm2dToQuantScaleBias, nn.BatchNorm2d), (qnn.QuantLinear, nn.BatchNorm1d), (qnn.QuantConv1d, nn.BatchNorm1d), (qnn.QuantConv2d, nn.BatchNorm2d), - (qnn.QuantConvTranspose1d, - nn.BatchNorm1d), (qnn.QuantConvTranspose2d, nn.BatchNorm2d)) + (qnn.QuantConv3d, + nn.BatchNorm3d), (qnn.QuantConvTranspose1d, nn.BatchNorm1d), + (qnn.QuantConvTranspose2d, + nn.BatchNorm2d), (qnn.QuantConvTranspose3d, nn.BatchNorm3d)) def __init__(self, patterns=DEFAULT_PATTERNS): super(MergeBatchNorm, self).__init__() diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index 2d312a549..e95839b62 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -146,7 +146,9 @@ def update_batch(self, module, input, current_layer): if isinstance(self.layer, SUPPORTED_CONV_OP): # Pick the correct unfoldNd class - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + if isinstance( + self.layer, + (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): unfold_impl = unfoldNd.UnfoldTransposeNd else: unfold_impl = unfoldNd.UnfoldNd @@ -210,7 +212,9 @@ def single_layer_update(self): dev = weight.device dtype = weight.dtype if isinstance(self.layer, SUPPORTED_CONV_OP): - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + if isinstance( + self.layer, + (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): weight = weight.transpose(1, 0) # This performs a view weight = weight.flatten(1) weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index b10943f1b..0f0d9762e 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -153,7 +153,9 @@ def update_batch(self, module, input, current_layer): if isinstance(self.layer, SUPPORTED_CONV_OP): # Pick the correct unfoldNd class - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + if isinstance( + self.layer, + (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): unfold_impl = unfoldNd.UnfoldTransposeNd else: unfold_impl = unfoldNd.UnfoldNd @@ -198,7 +200,9 @@ def single_layer_update(self, percdamp=.01): dtype = weight.dtype if isinstance(self.layer, SUPPORTED_CONV_OP): - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + if isinstance( + self.layer, + (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): weight = weight.transpose(1, 0) # This performs a view weight = weight.flatten(1) diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index 1279950a8..0fefbb094 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -16,7 +16,12 @@ from brevitas.quant_tensor import QuantTensor SUPPORTED_CONV_OP = ( - qnn.QuantConv2d, qnn.QuantConv1d, qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d) + qnn.QuantConv3d, + qnn.QuantConv2d, + qnn.QuantConv1d, + qnn.QuantConvTranspose1d, + qnn.QuantConvTranspose2d, + qnn.QuantConvTranspose3d) class 
StopFwdException(Exception): @@ -152,7 +157,9 @@ def __init__( # By default, use groups = 1 self.groups = 1 if isinstance(self.layer, SUPPORTED_CONV_OP): - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + if isinstance( + self.layer, + (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): weight = weight.transpose(1, 0) # This performs a view weight = weight.flatten(1) self.groups = self.layer.groups diff --git a/src/brevitas/graph/per_input.py b/src/brevitas/graph/per_input.py index 37f5f6eac..908c73a2d 100644 --- a/src/brevitas/graph/per_input.py +++ b/src/brevitas/graph/per_input.py @@ -10,6 +10,7 @@ from brevitas.graph.utils import replace_module from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d +from brevitas.nn import QuantConv3d from .base import PerInputModuleToModuleByHook @@ -93,9 +94,10 @@ def replace_modules(self, model): dw_conv = QuantConv1d(**kwargs) elif isinstance(avgpool, nn.AvgPool2d): dw_conv = QuantConv2d(**kwargs) + elif isinstance(avgpool, nn.AvgPool3d): + dw_conv = QuantConv3d(**kwargs) else: - assert isinstance(avgpool, nn.AvgPool3d) - raise RuntimeError("QuantConv3d not supported yet.") + raise RuntimeError("Unsupported operation.") kernel_value = 1. / reduce(mul, dw_conv.kernel_size) dw_conv.register_parameter( 'scalar_weight', torch.nn.Parameter(torch.tensor(kernel_value))) diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py index 63143c4e5..1817c04ed 100644 --- a/src/brevitas/graph/quantize.py +++ b/src/brevitas/graph/quantize.py @@ -81,6 +81,12 @@ 'weight_quant': Int8WeightPerTensorFloat, 'bias_quant': Int32Bias, 'return_quant_tensor': True}), + nn.Conv3d: ( + qnn.QuantConv3d, + { + 'weight_quant': Int8WeightPerTensorFloat, + 'bias_quant': Int32Bias, + 'return_quant_tensor': True}), nn.ConvTranspose1d: ( qnn.QuantConvTranspose1d, { @@ -93,6 +99,12 @@ 'weight_quant': Int8WeightPerTensorFloat, 'bias_quant': Int32Bias, 'return_quant_tensor': True}), + nn.ConvTranspose3d: ( + qnn.QuantConvTranspose3d, + { + 'weight_quant': Int8WeightPerTensorFloat, + 'bias_quant': Int32Bias, + 'return_quant_tensor': True}), nn.Linear: ( qnn.QuantLinear, { @@ -150,6 +162,13 @@ 'weight_quant': Int8WeightPerTensorFloat, 'bias_quant': Int32Bias, 'return_quant_tensor': False}), + nn.Conv3d: ( + qnn.QuantConv3d, + { + 'input_quant': Int8ActPerTensorFloat, + 'weight_quant': Int8WeightPerTensorFloat, + 'bias_quant': Int32Bias, + 'return_quant_tensor': False}), nn.ConvTranspose1d: ( qnn.QuantConvTranspose1d, { @@ -164,6 +183,13 @@ 'weight_quant': Int8WeightPerTensorFloat, 'bias_quant': Int32Bias, 'return_quant_tensor': False}), + nn.ConvTranspose3d: ( + qnn.QuantConvTranspose3d, + { + 'input_quant': Int8ActPerTensorFloat, + 'weight_quant': Int8WeightPerTensorFloat, + 'bias_quant': Int32Bias, + 'return_quant_tensor': False}), nn.Linear: ( qnn.QuantLinear, { diff --git a/src/brevitas/graph/target/flexml.py b/src/brevitas/graph/target/flexml.py index c10f688e3..3e73ce02a 100644 --- a/src/brevitas/graph/target/flexml.py +++ b/src/brevitas/graph/target/flexml.py @@ -45,6 +45,12 @@ 'weight_quant': Int8WeightPerTensorFixedPoint, 'bias_quant': Int16Bias, 'return_quant_tensor': True}), + nn.Conv3d: ( + qnn.QuantConv3d, + { + 'weight_quant': Int8WeightPerTensorFixedPoint, + 'bias_quant': Int16Bias, + 'return_quant_tensor': True}), nn.ConvTranspose1d: ( qnn.QuantConvTranspose1d, { @@ -57,6 +63,12 @@ 'weight_quant': Int8WeightPerTensorFixedPoint, 'bias_quant': Int16Bias, 'return_quant_tensor': 
True}),
+    nn.ConvTranspose3d: (
+        qnn.QuantConvTranspose3d,
+        {
+            'weight_quant': Int8WeightPerTensorFixedPoint,
+            'bias_quant': Int16Bias,
+            'return_quant_tensor': True}),
     nn.BatchNorm1d: (
         qnn.BatchNorm1dToQuantScaleBias,
         {

diff --git a/src/brevitas/graph/utils.py b/src/brevitas/graph/utils.py
index 2ddfc9b2c..ed00de3eb 100644
--- a/src/brevitas/graph/utils.py
+++ b/src/brevitas/graph/utils.py
@@ -31,7 +31,8 @@
     nn.ConvTranspose2d,
     nn.ConvTranspose3d,
     qnn.QuantConvTranspose1d,
-    qnn.QuantConvTranspose2d]
+    qnn.QuantConvTranspose2d,
+    qnn.QuantConvTranspose3d]
 
 
 def module_class_name(m: torch.nn.Module):

diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py
index ae292bdac..d5148e56f 100644
--- a/src/brevitas/nn/quant_conv.py
+++ b/src/brevitas/nn/quant_conv.py
@@ -292,11 +292,19 @@ def conv3d_same_zeros_pad_stride(self, x, weight, bias):
         pad_h = max((oh - 1) * self.stride[1] + (kh - 1) * self.dilation[1] + 1 - ih, 0)
         pad_w = max((ow - 1) * self.stride[2] + (kw - 1) * self.dilation[2] + 1 - iw, 0)
         if pad_d > 0 or pad_h > 0 or pad_w > 0:
-            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2, pad_d // 2, pad_d - pad_d // 2])
+            x = F.pad(
+                x,
+                [
+                    pad_w // 2,
+                    pad_w - pad_w // 2,
+                    pad_h // 2,
+                    pad_h - pad_h // 2,
+                    pad_d // 2,
+                    pad_d - pad_d // 2])
         out = F.conv3d(x, weight, bias, self.stride, 0, self.dilation, self.groups)
         return out
 
-    def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor,QuantTensor]:
+    def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         # calls QuantWBIOL.forward_impl and eventually inner_forward_impl below
         return self.forward_impl(input)

diff --git a/tests/brevitas/export/quant_module_fixture.py b/tests/brevitas/export/quant_module_fixture.py
index 8bfac6fd2..25db72984 100644
--- a/tests/brevitas/export/quant_module_fixture.py
+++ b/tests/brevitas/export/quant_module_fixture.py
@@ -9,8 +9,10 @@
 from brevitas.nn import QuantConv1d
 from brevitas.nn import QuantConv2d
+from brevitas.nn import QuantConv3d
 from brevitas.nn import QuantConvTranspose1d
 from brevitas.nn import QuantConvTranspose2d
+from brevitas.nn import QuantConvTranspose3d
 from brevitas.nn import QuantIdentity
 from brevitas.nn import QuantLinear
 from brevitas.nn import TruncAvgPool2d
@@ -53,7 +55,13 @@
     'bias_external_scale': (Int32Bias,),
     'bias_internal_scale': (Int8BiasPerTensorFloatInternalScaling,)}
 QUANT_WBIOL_IMPL = [
-    QuantLinear, QuantConv1d, QuantConv2d, QuantConvTranspose1d, QuantConvTranspose2d]
+    QuantLinear,
+    QuantConv1d,
+    QuantConv2d,
+    QuantConv3d,
+    QuantConvTranspose1d,
+    QuantConvTranspose2d,
+    QuantConvTranspose3d]
 BIT_WIDTHS = [4, 8, 10]  # below 8, equal 8, above 8
 BIAS_BIT_WIDTHS = [8, 16, 32]

diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py
index 875b34afa..5b2db000a 100644
--- a/tests/brevitas/nn/nn_quantizers_fixture.py
+++ b/tests/brevitas/nn/nn_quantizers_fixture.py
@@ -12,8 +12,10 @@
 from brevitas import torch_version
 from brevitas.nn import QuantConv1d
 from brevitas.nn import QuantConv2d
+from brevitas.nn import QuantConv3d
 from brevitas.nn import QuantConvTranspose1d
 from brevitas.nn import QuantConvTranspose2d
+from brevitas.nn import QuantConvTranspose3d
 from brevitas.nn import QuantIdentity
 from brevitas.nn import QuantLinear
 from brevitas.nn.quant_mha import QuantMultiheadAttention
@@ -86,8 +88,10 @@
     QuantLinear,
     QuantConv1d,
     QuantConv2d,
+    QuantConv3d,
     QuantConvTranspose1d,
-    QuantConvTranspose2d,]
+    QuantConvTranspose2d,
+    
QuantConvTranspose3d]
 
 ACC_BIT_WIDTHS = [8, 9, 10, 12, 16, 24, 32]

diff --git a/tests/brevitas/nn/test_a2q.py b/tests/brevitas/nn/test_a2q.py
index 2abcf9ef2..4940ec990 100644
--- a/tests/brevitas/nn/test_a2q.py
+++ b/tests/brevitas/nn/test_a2q.py
@@ -60,12 +60,18 @@ def test_quant_wbiol_a2q(model_input, current_cases):
     elif kwargs[
             'model_type'] == 'QuantConv2d':  # shape = (out_channels, in_channels, kernel_size, kernel_size)
         quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(1, 2, 3))
+    elif kwargs[
+            'model_type'] == 'QuantConv3d':  # shape = (out_channels, in_channels, kernel_size, kernel_size, kernel_size)
+        quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(1, 2, 3, 4))
     elif kwargs[
             'model_type'] == 'QuantConvTranspose1d':  # shape = (in_channels, out_channels, kernel_size)
         quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2))
     elif kwargs[
             'model_type'] == 'QuantConvTranspose2d':  # shape = (in_channels, out_channels, kernel_size, kernel_size)
         quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2, 3))
+    elif kwargs[
+            'model_type'] == 'QuantConvTranspose3d':  # shape = (in_channels, out_channels, kernel_size, kernel_size, kernel_size)
+        quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2, 3, 4))
     else:
         raise NotImplementedError(f"Check for {kwargs['model_type']} is not yet implemented.")

diff --git a/tests/brevitas/nn/test_wbiol.py b/tests/brevitas/nn/test_wbiol.py
index 3577d25a1..fee8b4cfa 100644
--- a/tests/brevitas/nn/test_wbiol.py
+++ b/tests/brevitas/nn/test_wbiol.py
@@ -7,8 +7,10 @@
 from brevitas.nn import QuantConv1d
 from brevitas.nn import QuantConv2d
+from brevitas.nn import QuantConv3d
 from brevitas.nn import QuantConvTranspose1d
 from brevitas.nn import QuantConvTranspose2d
+from brevitas.nn import QuantConvTranspose3d
 from brevitas.nn import QuantLinear
 from brevitas.nn import QuantScaleBias
 from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
@@ -25,7 +27,13 @@
 IN_CH = 5
 KERNEL_SIZE = 3
 
-QUANT_CONV_VARIANTS = [QuantConv1d, QuantConv2d, QuantConvTranspose1d, QuantConvTranspose2d]
+QUANT_CONV_VARIANTS = [
+    QuantConv1d,
+    QuantConv2d,
+    QuantConv3d,
+    QuantConvTranspose1d,
+    QuantConvTranspose2d,
+    QuantConvTranspose3d]
 
 
 @pytest_cases.fixture()

diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py
index c05fd59b9..ddda4c0e5 100644
--- a/tests/brevitas_ort/common.py
+++ b/tests/brevitas_ort/common.py
@@ -14,8 +14,10 @@
 from brevitas.export import export_qonnx
 from brevitas.nn import QuantConv1d
 from brevitas.nn import QuantConv2d
+from brevitas.nn import QuantConv3d
 from brevitas.nn import QuantConvTranspose1d
 from brevitas.nn import QuantConvTranspose2d
+from brevitas.nn import QuantConvTranspose3d
 from brevitas.nn import QuantIdentity
 from brevitas.nn import QuantLinear
 from brevitas.nn import QuantLSTM
@@ -69,7 +71,13 @@ class A2QWeightQuantizerForTests(Int8AccumulatorAwareWeightQuant):
     'symmetric_per_channel_fixed_point': (Int8WeightPerChannelFixedPoint, Int8ActPerTensorFixedPoint)}
 
 QUANT_WBIOL_IMPL = [
-    QuantLinear, QuantConv1d, QuantConv2d, QuantConvTranspose1d, QuantConvTranspose2d]
+    QuantLinear,
+    QuantConv1d,
+    QuantConv2d,
+    QuantConv3d,
+    QuantConvTranspose1d,
+    QuantConvTranspose2d,
+    QuantConvTranspose3d]
 
 
 def compute_ort(export_name, np_input):

diff --git a/tests/brevitas_ort/test_quant_module.py b/tests/brevitas_ort/test_quant_module.py
index a4f7e7f5c..e8e1272cb 100644
--- a/tests/brevitas_ort/test_quant_module.py
+++ b/tests/brevitas_ort/test_quant_module.py
@@ -28,7 +28,8 @@ def 
test_ort_wbiol(model, export_type, current_cases):
         -2]  # Inverse list of definition, 'export_type' is -1, 'impl' is -2, etc.
     quantizer = case_id.split('-')[-6]
 
-    if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d') and export_type == 'qop':
+    if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d',
+                'QuantConvTranspose3d') and export_type == 'qop':
         pytest.skip('Export of ConvTranspose is not supported for QOperation')
     if 'per_channel' in quantizer and 'asymmetric' in quantizer:
         pytest.skip('Per-channel zero-point is not well supported in ORT.')

From a02aa23e795650c287a21335d2c22e88bc227e5c Mon Sep 17 00:00:00 2001
From: costigt-dev <156176839+costigt-dev@users.noreply.github.com>
Date: Tue, 30 Jan 2024 09:06:45 +0000
Subject: [PATCH 11/37] Revert "updated references to QuantConv and
 QuantConvTranspose throughout code to include QuantConv3d and
 QuantConvTranspose3d"

This reverts commit ff281b814d1f2e5ef3e4d671de411b6d5f9b170a.

---
 src/brevitas/export/onnx/handler.py           | 32 -------------------
 .../standard/qoperator/handler/parameter.py   | 10 ++----
 .../export/onnx/standard/qoperator/manager.py |  2 --
 .../torch/qoperator/handler/parameter.py      | 11 +------
 .../export/torch/qoperator/manager.py         |  2 --
 src/brevitas/graph/fixed_point.py             | 11 ++-----
 src/brevitas/graph/gpfq.py                    |  8 ++---
 src/brevitas/graph/gptq.py                    |  8 ++---
 src/brevitas/graph/gpxq.py                    | 11 ++-----
 src/brevitas/graph/per_input.py               |  6 ++--
 src/brevitas/graph/quantize.py                | 26 ---------------
 src/brevitas/graph/target/flexml.py           | 12 -------
 src/brevitas/graph/utils.py                   |  3 +-
 src/brevitas/nn/quant_conv.py                 | 12 ++-----
 tests/brevitas/export/quant_module_fixture.py | 10 +-----
 tests/brevitas/nn/nn_quantizers_fixture.py    |  6 +---
 tests/brevitas/nn/test_a2q.py                 |  6 ----
 tests/brevitas/nn/test_wbiol.py               | 10 +-----
 tests/brevitas_ort/common.py                  | 10 +-----
 tests/brevitas_ort/test_quant_module.py       |  3 +-
 20 files changed, 22 insertions(+), 177 deletions(-)

diff --git a/src/brevitas/export/onnx/handler.py b/src/brevitas/export/onnx/handler.py
index 9f657f10c..f856eaeee 100644
--- a/src/brevitas/export/onnx/handler.py
+++ b/src/brevitas/export/onnx/handler.py
@@ -88,38 +88,6 @@ def kernel_shape(module):
         return list(module.kernel_size)
 
 
-class Kernel3dApplHandlerMixin(ABC):
-
-    @staticmethod
-    def padding(module):
-        if isinstance(module.padding, int):
-            padding = [module.padding] * 6
-        else:
-            padding = list(module.padding) + list(module.padding)
-        return padding
-
-    @staticmethod
-    def stride(module):
-        if isinstance(module.stride, int):
-            return [module.stride] * 3
-        else:
-            return list(module.stride)
-
-    @staticmethod
-    def dilation(module):
-        if isinstance(module.dilation, int):
-            return [module.dilation] * 3
-        else:
-            return list(module.dilation)
-
-    @staticmethod
-    def kernel_shape(module):
-        if isinstance(module.kernel_size, int):
-            return [module.kernel_size] * 3
-        else:
-            return list(module.kernel_size)
-
-
 class ONNXBaseHandler(BaseHandler, ABC):
 
     def __init__(self):

diff --git a/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py b/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py
index 6540b6f2d..c4c552c65 100644
--- a/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py
+++ b/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py
@@ -8,13 +8,11 @@
 from brevitas.export.common import to_0dim_if_scalar
 from brevitas.export.onnx.handler import Kernel1dApplHandlerMixin
 from brevitas.export.onnx.handler import Kernel2dApplHandlerMixin
-from brevitas.export.onnx.handler import 
Kernel3dApplHandlerMixin from brevitas.export.onnx.standard.function import DequantizeLinearFn from brevitas.export.onnx.standard.function import IntClipFn from brevitas.export.onnx.standard.function import QuantizeLinearFn from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d -from brevitas.nn import QuantConv3d from brevitas.nn import QuantLinear from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL @@ -58,7 +56,7 @@ def validate(cls, module: QuantWBIOL, requires_quant_bias=True): assert module.is_quant_bias_signed cls.validate_32b_bit_width(module.quant_bias_bit_width(), le_then=True) - def prepare_for_export(self, module: Union[QuantConv1d, QuantConv2d, QuantConv3d]): + def prepare_for_export(self, module: Union[QuantConv1d, QuantConv2d]): self.validate(module) op_symbolic_kwargs = self.op_symbolic_kwargs(module) @@ -108,7 +106,7 @@ def output_symbolic_execution(self, out: Tensor): class StdQOpONNXQuantConvNdHandler(StdQOpONNXQuantWBIOLHandler, ABC): - def op_symbolic_kwargs(self, module: Union[QuantConv1d, QuantConv2d, QuantConv3d]): + def op_symbolic_kwargs(self, module: Union[QuantConv1d, QuantConv2d]): conv_symbolic_kwargs = { 'input_scale': module.quant_input_scale(), 'input_zero_point': self.quant_input_zero_point(module), @@ -133,10 +131,6 @@ def op_symbolic_execution(self, inp: Tensor): return out -class StdQOpONNXQuantConv3dHandler(StdQOpONNXQuantConvNdHandler, Kernel3dApplHandlerMixin): - handled_layer = QuantConv3d - - class StdQOpONNXQuantConv2dHandler(StdQOpONNXQuantConvNdHandler, Kernel2dApplHandlerMixin): handled_layer = QuantConv2d diff --git a/src/brevitas/export/onnx/standard/qoperator/manager.py b/src/brevitas/export/onnx/standard/qoperator/manager.py index 174804407..4c45df04c 100644 --- a/src/brevitas/export/onnx/standard/qoperator/manager.py +++ b/src/brevitas/export/onnx/standard/qoperator/manager.py @@ -26,7 +26,6 @@ from .handler.base import StdQOpONNXQuantLayerHandler from .handler.parameter import StdQOpONNXQuantConv1dHandler from .handler.parameter import StdQOpONNXQuantConv2dHandler -from .handler.parameter import StdQOpONNXQuantConv3dHandler from .handler.parameter import StdQOpONNXQuantLinearHandler from .handler.pool import StdQOpONNXQuantMaxPool1d from .handler.pool import StdQOpONNXQuantMaxPool2d @@ -49,7 +48,6 @@ class StdQOpONNXManager(StdONNXBaseManager): handlers = [ StdQOpONNXQuantConv1dHandler, StdQOpONNXQuantConv2dHandler, - StdQOpONNXQuantConv3dHandler, StdQOpONNXQuantLinearHandler, StdQOpONNXQuantReLUHandler, StdQOpONNXQuantHardTanhHandler, diff --git a/src/brevitas/export/torch/qoperator/handler/parameter.py b/src/brevitas/export/torch/qoperator/handler/parameter.py index 802a5a053..fa110a84c 100644 --- a/src/brevitas/export/torch/qoperator/handler/parameter.py +++ b/src/brevitas/export/torch/qoperator/handler/parameter.py @@ -10,7 +10,6 @@ from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d -from brevitas.nn import QuantConv3d from brevitas.nn import QuantLinear from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL @@ -94,7 +93,7 @@ def explicit_output_dtype(cls): return True @classmethod - def prepare_qf_kwargs(cls, module: Union[QuantConv1d, QuantConv2d, QuantConv3d]): + def prepare_qf_kwargs(cls, module: Union[QuantConv1d, QuantConv2d]): return { 'bias': cls.prepare_bias(module), 'stride': module.stride, @@ -120,14 +119,6 @@ def prepare_qf(cls, module: QuantConv2d): return torch.nn.quantized.functional.conv2d, cls.prepare_qf_kwargs(module) -class 
PytorchQuantConv3dHandler(PytorchQuantConvNdHandler): - handled_layer = QuantConv3d - - @classmethod - def prepare_qf(cls, module: QuantConv3d): - return torch.nn.quantized.functional.conv3d, cls.prepare_qf_kwargs(module) - - class PytorchQuantLinearHandler(PytorchQuantWBIOLHandler): handled_layer = QuantLinear diff --git a/src/brevitas/export/torch/qoperator/manager.py b/src/brevitas/export/torch/qoperator/manager.py index 7567e6770..aea97957c 100644 --- a/src/brevitas/export/torch/qoperator/manager.py +++ b/src/brevitas/export/torch/qoperator/manager.py @@ -19,7 +19,6 @@ from .handler.act import PytorchQuantReLUHandler from .handler.parameter import PytorchQuantConv1dHandler from .handler.parameter import PytorchQuantConv2dHandler -from .handler.parameter import PytorchQuantConv3dHandler from .handler.parameter import PytorchQuantLinearHandler from .handler.pool import PytorchQuantMaxPool1d from .handler.pool import PytorchQuantMaxPool2d @@ -36,7 +35,6 @@ class TorchQOpManager(BaseManager): PytorchQuantReLUHandler, PytorchQuantConv1dHandler, PytorchQuantConv2dHandler, - PytorchQuantConv3dHandler, PytorchQuantLinearHandler] @classmethod diff --git a/src/brevitas/graph/fixed_point.py b/src/brevitas/graph/fixed_point.py index fead31ba2..afbd32e67 100644 --- a/src/brevitas/graph/fixed_point.py +++ b/src/brevitas/graph/fixed_point.py @@ -33,14 +33,11 @@ class MoveSplitBatchNormBeforeCat(UntilFixedPointGraphTransform): nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, - nn.ConvTranspose3d, qnn.QuantLinear, qnn.QuantConv1d, qnn.QuantConv2d, - qnn.QuantConv3d, qnn.QuantConvTranspose1d, - qnn.QuantConvTranspose2d, - qnn.QuantConvTranspose3d) + qnn.QuantConvTranspose2d) def __init__(self, before_modules_types=DEFAULT_BEFORE_MODULES_TYPES): super(MoveSplitBatchNormBeforeCat, self).__init__() @@ -96,10 +93,8 @@ class MergeBatchNorm(UntilFixedPointGraphTransform): nn.BatchNorm1d), (qnn.BatchNorm2dToQuantScaleBias, nn.BatchNorm2d), (qnn.QuantLinear, nn.BatchNorm1d), (qnn.QuantConv1d, nn.BatchNorm1d), (qnn.QuantConv2d, nn.BatchNorm2d), - (qnn.QuantConv3d, - nn.BatchNorm3d), (qnn.QuantConvTranspose1d, nn.BatchNorm1d), - (qnn.QuantConvTranspose2d, - nn.BatchNorm2d), (qnn.QuantConvTranspose3d, nn.BatchNorm3d)) + (qnn.QuantConvTranspose1d, + nn.BatchNorm1d), (qnn.QuantConvTranspose2d, nn.BatchNorm2d)) def __init__(self, patterns=DEFAULT_PATTERNS): super(MergeBatchNorm, self).__init__() diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index 19efbf9e8..cad5d9043 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -156,9 +156,7 @@ def update_batch(self, module, input, current_layer): if isinstance(self.layer, SUPPORTED_CONV_OP): # Pick the correct unfoldNd class - if isinstance( - self.layer, - (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): + if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): unfold_impl = unfoldNd.UnfoldTransposeNd else: unfold_impl = unfoldNd.UnfoldNd @@ -222,9 +220,7 @@ def single_layer_update(self): dev = weight.device dtype = weight.dtype if isinstance(self.layer, SUPPORTED_CONV_OP): - if isinstance( - self.layer, - (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): + if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): weight = weight.transpose(1, 0) # This performs a view weight = weight.flatten(1) weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] diff --git a/src/brevitas/graph/gptq.py 
b/src/brevitas/graph/gptq.py index f1261f34b..56171ac6f 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -148,9 +148,7 @@ def update_batch(self, module, input, current_layer): if isinstance(self.layer, SUPPORTED_CONV_OP): # Pick the correct unfoldNd class - if isinstance( - self.layer, - (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): + if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): unfold_impl = unfoldNd.UnfoldTransposeNd else: unfold_impl = unfoldNd.UnfoldNd @@ -195,9 +193,7 @@ def single_layer_update(self, percdamp=.01): dtype = weight.dtype if isinstance(self.layer, SUPPORTED_CONV_OP): - if isinstance( - self.layer, - (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): + if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): weight = weight.transpose(1, 0) # This performs a view weight = weight.flatten(1) diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index 870bf6d69..ddeef1c53 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -18,12 +18,7 @@ from brevitas.quant_tensor import QuantTensor SUPPORTED_CONV_OP = ( - qnn.QuantConv3d, - qnn.QuantConv2d, - qnn.QuantConv1d, - qnn.QuantConvTranspose1d, - qnn.QuantConvTranspose2d, - qnn.QuantConvTranspose3d) + qnn.QuantConv2d, qnn.QuantConv1d, qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d) class StopFwdException(Exception): @@ -159,9 +154,7 @@ def __init__( # By default, use groups = 1 self.groups = 1 if isinstance(self.layer, SUPPORTED_CONV_OP): - if isinstance( - self.layer, - (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): + if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): weight = weight.transpose(1, 0) # This performs a view weight = weight.flatten(1) self.groups = self.layer.groups diff --git a/src/brevitas/graph/per_input.py b/src/brevitas/graph/per_input.py index 908c73a2d..37f5f6eac 100644 --- a/src/brevitas/graph/per_input.py +++ b/src/brevitas/graph/per_input.py @@ -10,7 +10,6 @@ from brevitas.graph.utils import replace_module from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d -from brevitas.nn import QuantConv3d from .base import PerInputModuleToModuleByHook @@ -94,10 +93,9 @@ def replace_modules(self, model): dw_conv = QuantConv1d(**kwargs) elif isinstance(avgpool, nn.AvgPool2d): dw_conv = QuantConv2d(**kwargs) - elif isinstance(avgpool, nn.AvgPool3d): - dw_conv = QuantConv3d(**kwargs) else: - raise RuntimeError("Unsupported operation.") + assert isinstance(avgpool, nn.AvgPool3d) + raise RuntimeError("QuantConv3d not supported yet.") kernel_value = 1. 
/ reduce(mul, dw_conv.kernel_size) dw_conv.register_parameter( 'scalar_weight', torch.nn.Parameter(torch.tensor(kernel_value))) diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py index 1817c04ed..63143c4e5 100644 --- a/src/brevitas/graph/quantize.py +++ b/src/brevitas/graph/quantize.py @@ -81,12 +81,6 @@ 'weight_quant': Int8WeightPerTensorFloat, 'bias_quant': Int32Bias, 'return_quant_tensor': True}), - nn.Conv3d: ( - qnn.QuantConv3d, - { - 'weight_quant': Int8WeightPerTensorFloat, - 'bias_quant': Int32Bias, - 'return_quant_tensor': True}), nn.ConvTranspose1d: ( qnn.QuantConvTranspose1d, { @@ -99,12 +93,6 @@ 'weight_quant': Int8WeightPerTensorFloat, 'bias_quant': Int32Bias, 'return_quant_tensor': True}), - nn.ConvTranspose3d: ( - qnn.QuantConvTranspose3d, - { - 'weight_quant': Int8WeightPerTensorFloat, - 'bias_quant': Int32Bias, - 'return_quant_tensor': True}), nn.Linear: ( qnn.QuantLinear, { @@ -162,13 +150,6 @@ 'weight_quant': Int8WeightPerTensorFloat, 'bias_quant': Int32Bias, 'return_quant_tensor': False}), - nn.Conv3d: ( - qnn.QuantConv3d, - { - 'input_quant': Int8ActPerTensorFloat, - 'weight_quant': Int8WeightPerTensorFloat, - 'bias_quant': Int32Bias, - 'return_quant_tensor': False}), nn.ConvTranspose1d: ( qnn.QuantConvTranspose1d, { @@ -183,13 +164,6 @@ 'weight_quant': Int8WeightPerTensorFloat, 'bias_quant': Int32Bias, 'return_quant_tensor': False}), - nn.ConvTranspose3d: ( - qnn.QuantConvTranspose3d, - { - 'input_quant': Int8ActPerTensorFloat, - 'weight_quant': Int8WeightPerTensorFloat, - 'bias_quant': Int32Bias, - 'return_quant_tensor': False}), nn.Linear: ( qnn.QuantLinear, { diff --git a/src/brevitas/graph/target/flexml.py b/src/brevitas/graph/target/flexml.py index 35b9a4d14..9aedd337c 100644 --- a/src/brevitas/graph/target/flexml.py +++ b/src/brevitas/graph/target/flexml.py @@ -45,12 +45,6 @@ 'weight_quant': Int8WeightPerTensorFixedPoint, 'bias_quant': Int16Bias, 'return_quant_tensor': True}), - nn.Conv3d: ( - qnn.QuantConv3d, - { - 'weight_quant': Int8WeightPerTensorFixedPoint, - 'bias_quant': Int16Bias, - 'return_quant_tensor': True}), nn.ConvTranspose1d: ( qnn.QuantConvTranspose1d, { @@ -63,12 +57,6 @@ 'weight_quant': Int8WeightPerTensorFixedPoint, 'bias_quant': Int16Bias, 'return_quant_tensor': True}), - nn.ConvTranspose3d: ( - qnn.QuantConvTranspose3d, - { - 'weight_quant': Int8WeightPerTensorFixedPoint, - 'bias_quant': Int16Bias, - 'return_quant_tensor': True}), nn.BatchNorm1d: ( qnn.BatchNorm1dToQuantScaleBias, { diff --git a/src/brevitas/graph/utils.py b/src/brevitas/graph/utils.py index ed00de3eb..2ddfc9b2c 100644 --- a/src/brevitas/graph/utils.py +++ b/src/brevitas/graph/utils.py @@ -31,8 +31,7 @@ nn.ConvTranspose2d, nn.ConvTranspose3d, qnn.QuantConvTranspose1d, - qnn.QuantConvTranspose2d, - qnn.QuantConvTranspose3d] + qnn.QuantConvTranspose2d] def module_class_name(m: torch.nn.Module): diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py index d5148e56f..ae292bdac 100644 --- a/src/brevitas/nn/quant_conv.py +++ b/src/brevitas/nn/quant_conv.py @@ -292,19 +292,11 @@ def conv3d_same_zeros_pad_stride(self, x, weight, bias): pad_h = max((oh - 1) * self.stride[1] + (kh - 1) * self.dilation[1] + 1 - ih, 0) pad_w = max((ow - 1) * self.stride[2] + (kw - 1) * self.dilation[2] + 1 - iw, 0) if pad_h > 0 or pad_w > 0: - x = F.pad( - x, - [ - pad_w // 2, - pad_w - pad_w // 2, - pad_h // 2, - pad_h - pad_h // 2, - pad_d // 2, - pad_d - pad_d // 2]) + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2, pad_d // 
2, pad_d - pad_d // 2]) out = F.conv2d(x, weight, bias, self.stride, 0, self.dilation, self.groups) return out - def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: + def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor,QuantTensor]: # calls QuantWBIOL.forward_impl and eventually inner_forward_impl below return self.forward_impl(input) diff --git a/tests/brevitas/export/quant_module_fixture.py b/tests/brevitas/export/quant_module_fixture.py index 25db72984..8bfac6fd2 100644 --- a/tests/brevitas/export/quant_module_fixture.py +++ b/tests/brevitas/export/quant_module_fixture.py @@ -9,10 +9,8 @@ from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d -from brevitas.nn import QuantConv3d from brevitas.nn import QuantConvTranspose1d from brevitas.nn import QuantConvTranspose2d -from brevitas.nn import QuantConvTranspose3d from brevitas.nn import QuantIdentity from brevitas.nn import QuantLinear from brevitas.nn import TruncAvgPool2d @@ -55,13 +53,7 @@ 'bias_external_scale': (Int32Bias,), 'bias_internal_scale': (Int8BiasPerTensorFloatInternalScaling,)} QUANT_WBIOL_IMPL = [ - QuantLinear, - QuantConv1d, - QuantConv2d, - QuantConv3d, - QuantConvTranspose1d, - QuantConvTranspose2d, - QuantConvTranspose3d] + QuantLinear, QuantConv1d, QuantConv2d, QuantConvTranspose1d, QuantConvTranspose2d] BIT_WIDTHS = [4, 8, 10] # below 8, equal 8, above 8 BIAS_BIT_WIDTHS = [8, 16, 32] diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py index f099e3b91..538e836e8 100644 --- a/tests/brevitas/nn/nn_quantizers_fixture.py +++ b/tests/brevitas/nn/nn_quantizers_fixture.py @@ -12,10 +12,8 @@ from brevitas import torch_version from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d -from brevitas.nn import QuantConv3d from brevitas.nn import QuantConvTranspose1d from brevitas.nn import QuantConvTranspose2d -from brevitas.nn import QuantConvTranspose3d from brevitas.nn import QuantIdentity from brevitas.nn import QuantLinear from brevitas.nn.quant_mha import QuantMultiheadAttention @@ -93,10 +91,8 @@ QuantLinear, QuantConv1d, QuantConv2d, - QuantConv3d, QuantConvTranspose1d, - QuantConvTranspose2d, - QuantConvTranspose3d] + QuantConvTranspose2d,] ACC_BIT_WIDTHS = [8, 9, 10, 12, 16, 24, 32] diff --git a/tests/brevitas/nn/test_a2q.py b/tests/brevitas/nn/test_a2q.py index d34bd6683..8c34f390c 100644 --- a/tests/brevitas/nn/test_a2q.py +++ b/tests/brevitas/nn/test_a2q.py @@ -102,18 +102,12 @@ def test_quant_wbiol_a2q(model_input, current_cases): elif kwargs[ 'model_type'] == 'QuantConv2d': # shape = (out_channels, in_channels, kernel_size, kernel_size) quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(1, 2, 3)) - elif kwargs[ - 'model_type'] == 'QuantConv3d': # shape = (out_channels, in_channels, kernel_size, kernel_size) - quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(1, 2, 3, 4)) elif kwargs[ 'model_type'] == 'QuantConvTranspose1d': # shape = (in_channels, out_channels, kernel_size) quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2)) elif kwargs[ 'model_type'] == 'QuantConvTranspose2d': # shape = (in_channels, out_channels, kernel_size) quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2, 3)) - elif kwargs[ - 'model_type'] == 'QuantConvTranspose3d': # shape = (in_channels, out_channels, kernel_size) - quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2, 3, 4)) else: raise NotImplementedError(f"Check for 
{kwargs['model_type']} is not yet implemented.") diff --git a/tests/brevitas/nn/test_wbiol.py b/tests/brevitas/nn/test_wbiol.py index fee8b4cfa..3577d25a1 100644 --- a/tests/brevitas/nn/test_wbiol.py +++ b/tests/brevitas/nn/test_wbiol.py @@ -7,10 +7,8 @@ from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d -from brevitas.nn import QuantConv3d from brevitas.nn import QuantConvTranspose1d from brevitas.nn import QuantConvTranspose2d -from brevitas.nn import QuantConvTranspose3d from brevitas.nn import QuantLinear from brevitas.nn import QuantScaleBias from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL @@ -27,13 +25,7 @@ IN_CH = 5 KERNEL_SIZE = 3 -QUANT_CONV_VARIANTS = [ - QuantConv1d, - QuantConv2d, - QuantConv3d, - QuantConvTranspose1d, - QuantConvTranspose2d, - QuantConvTranspose3d] +QUANT_CONV_VARIANTS = [QuantConv1d, QuantConv2d, QuantConvTranspose1d, QuantConvTranspose2d] @pytest_cases.fixture() diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index 415d79a6b..e97f21fca 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -14,10 +14,8 @@ from brevitas.export import export_qonnx from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d -from brevitas.nn import QuantConv3d from brevitas.nn import QuantConvTranspose1d from brevitas.nn import QuantConvTranspose2d -from brevitas.nn import QuantConvTranspose3d from brevitas.nn import QuantIdentity from brevitas.nn import QuantLinear from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint @@ -72,13 +70,7 @@ class A2QWeightQuantizerForTests(Int8AccumulatorAwareWeightQuant): 'symmetric_per_channel_fixed_point': (Int8WeightPerChannelFixedPoint, Int8ActPerTensorFixedPoint)} QUANT_WBIOL_IMPL = [ - QuantLinear, - QuantConv1d, - QuantConv2d, - QuantConv3d, - QuantConvTranspose1d, - QuantConvTranspose2d, - QuantConvTranspose3d] + QuantLinear, QuantConv1d, QuantConv2d, QuantConvTranspose1d, QuantConvTranspose2d] def compute_ort(export_name, np_input): diff --git a/tests/brevitas_ort/test_quant_module.py b/tests/brevitas_ort/test_quant_module.py index 619bcc93c..0b1277686 100644 --- a/tests/brevitas_ort/test_quant_module.py +++ b/tests/brevitas_ort/test_quant_module.py @@ -30,8 +30,7 @@ def test_ort_wbiol(model, export_type, current_cases): o_bit_width = case_id.split('-')[-5] i_bit_width = case_id.split('-')[-3] - if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d', - 'QuantConvTranspose3d') and export_type == 'qop': + if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d') and export_type == 'qop': pytest.skip('Export of ConvTranspose is not supported for QOperation') if 'per_channel' in quantizer and 'asymmetric' in quantizer: pytest.skip('Per-channel zero-point is not well supported in ORT.') From 10b331ec70acd06a8092a10ca69445aa3fa666d9 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Tue, 30 Jan 2024 09:27:43 +0000 Subject: [PATCH 12/37] pre-commit hook changes --- src/brevitas/nn/quant_conv.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py index ae292bdac..d5148e56f 100644 --- a/src/brevitas/nn/quant_conv.py +++ b/src/brevitas/nn/quant_conv.py @@ -292,11 +292,19 @@ def conv3d_same_zeros_pad_stride(self, x, weight, bias): pad_h = max((oh - 1) * self.stride[1] + (kh - 1) * self.dilation[1] + 1 - ih, 0) pad_w = max((ow - 1) * self.stride[2] + (kw - 1) * self.dilation[2] + 1 
- iw, 0) if pad_h > 0 or pad_w > 0: - x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2, pad_d // 2, pad_d - pad_d // 2]) + x = F.pad( + x, + [ + pad_w // 2, + pad_w - pad_w // 2, + pad_h // 2, + pad_h - pad_h // 2, + pad_d // 2, + pad_d - pad_d // 2]) out = F.conv2d(x, weight, bias, self.stride, 0, self.dilation, self.groups) return out - def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor,QuantTensor]: + def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: # calls QuantWBIOL.forward_impl and eventually inner_forward_impl below return self.forward_impl(input) From 57d8d8bfa3f6f798236176db2ce68810df5c1bd4 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Mon, 29 Jan 2024 15:14:04 +0000 Subject: [PATCH 13/37] updated references to QuantConv and QuantConvTranpose thoughout code to include QuantConv3d and QuantConvTranspose3d --- src/brevitas/export/onnx/handler.py | 32 +++++++++++++++++++ .../standard/qoperator/handler/parameter.py | 10 ++++-- .../export/onnx/standard/qoperator/manager.py | 2 ++ .../torch/qoperator/handler/parameter.py | 11 ++++++- .../export/torch/qoperator/manager.py | 2 ++ src/brevitas/graph/fixed_point.py | 11 +++++-- src/brevitas/graph/gpfq.py | 8 +++-- src/brevitas/graph/gptq.py | 8 +++-- src/brevitas/graph/gpxq.py | 11 +++++-- src/brevitas/graph/per_input.py | 6 ++-- src/brevitas/graph/quantize.py | 26 +++++++++++++++ src/brevitas/graph/target/flexml.py | 12 +++++++ src/brevitas/graph/utils.py | 3 +- tests/brevitas/export/quant_module_fixture.py | 10 +++++- tests/brevitas/nn/nn_quantizers_fixture.py | 6 +++- tests/brevitas/nn/test_a2q.py | 6 ++++ tests/brevitas/nn/test_wbiol.py | 10 +++++- tests/brevitas_ort/common.py | 10 +++++- tests/brevitas_ort/test_quant_module.py | 3 +- 19 files changed, 167 insertions(+), 20 deletions(-) diff --git a/src/brevitas/export/onnx/handler.py b/src/brevitas/export/onnx/handler.py index f856eaeee..9f657f10c 100644 --- a/src/brevitas/export/onnx/handler.py +++ b/src/brevitas/export/onnx/handler.py @@ -88,6 +88,38 @@ def kernel_shape(module): return list(module.kernel_size) +class Kernel3dApplHandlerMixin(ABC): + + @staticmethod + def padding(module): + if isinstance(module.padding, int): + padding = [module.padding] * 6 + else: + padding = list(module.padding) + list(module.padding) + return padding + + @staticmethod + def stride(module): + if isinstance(module.stride, int): + return [module.stride] * 4 + else: + return list(module.stride) + + @staticmethod + def dilation(module): + if isinstance(module.dilation, int): + return [module.dilation] * 4 + else: + return list(module.dilation) + + @staticmethod + def kernel_shape(module): + if isinstance(module.kernel_size, int): + return [module.kernel_size] * 4 + else: + return list(module.kernel_size) + + class ONNXBaseHandler(BaseHandler, ABC): def __init__(self): diff --git a/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py b/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py index c4c552c65..6540b6f2d 100644 --- a/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py +++ b/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py @@ -8,11 +8,13 @@ from brevitas.export.common import to_0dim_if_scalar from brevitas.export.onnx.handler import Kernel1dApplHandlerMixin from brevitas.export.onnx.handler import Kernel2dApplHandlerMixin +from brevitas.export.onnx.handler import Kernel3dApplHandlerMixin from 
brevitas.export.onnx.standard.function import DequantizeLinearFn from brevitas.export.onnx.standard.function import IntClipFn from brevitas.export.onnx.standard.function import QuantizeLinearFn from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d +from brevitas.nn import QuantConv3d from brevitas.nn import QuantLinear from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL @@ -56,7 +58,7 @@ def validate(cls, module: QuantWBIOL, requires_quant_bias=True): assert module.is_quant_bias_signed cls.validate_32b_bit_width(module.quant_bias_bit_width(), le_then=True) - def prepare_for_export(self, module: Union[QuantConv1d, QuantConv2d]): + def prepare_for_export(self, module: Union[QuantConv1d, QuantConv2d, QuantConv3d]): self.validate(module) op_symbolic_kwargs = self.op_symbolic_kwargs(module) @@ -106,7 +108,7 @@ def output_symbolic_execution(self, out: Tensor): class StdQOpONNXQuantConvNdHandler(StdQOpONNXQuantWBIOLHandler, ABC): - def op_symbolic_kwargs(self, module: Union[QuantConv1d, QuantConv2d]): + def op_symbolic_kwargs(self, module: Union[QuantConv1d, QuantConv2d, QuantConv3d]): conv_symbolic_kwargs = { 'input_scale': module.quant_input_scale(), 'input_zero_point': self.quant_input_zero_point(module), @@ -131,6 +133,10 @@ def op_symbolic_execution(self, inp: Tensor): return out +class StdQOpONNXQuantConv3dHandler(StdQOpONNXQuantConvNdHandler, Kernel3dApplHandlerMixin): + handled_layer = QuantConv3d + + class StdQOpONNXQuantConv2dHandler(StdQOpONNXQuantConvNdHandler, Kernel2dApplHandlerMixin): handled_layer = QuantConv2d diff --git a/src/brevitas/export/onnx/standard/qoperator/manager.py b/src/brevitas/export/onnx/standard/qoperator/manager.py index 4c45df04c..174804407 100644 --- a/src/brevitas/export/onnx/standard/qoperator/manager.py +++ b/src/brevitas/export/onnx/standard/qoperator/manager.py @@ -26,6 +26,7 @@ from .handler.base import StdQOpONNXQuantLayerHandler from .handler.parameter import StdQOpONNXQuantConv1dHandler from .handler.parameter import StdQOpONNXQuantConv2dHandler +from .handler.parameter import StdQOpONNXQuantConv3dHandler from .handler.parameter import StdQOpONNXQuantLinearHandler from .handler.pool import StdQOpONNXQuantMaxPool1d from .handler.pool import StdQOpONNXQuantMaxPool2d @@ -48,6 +49,7 @@ class StdQOpONNXManager(StdONNXBaseManager): handlers = [ StdQOpONNXQuantConv1dHandler, StdQOpONNXQuantConv2dHandler, + StdQOpONNXQuantConv3dHandler, StdQOpONNXQuantLinearHandler, StdQOpONNXQuantReLUHandler, StdQOpONNXQuantHardTanhHandler, diff --git a/src/brevitas/export/torch/qoperator/handler/parameter.py b/src/brevitas/export/torch/qoperator/handler/parameter.py index fa110a84c..802a5a053 100644 --- a/src/brevitas/export/torch/qoperator/handler/parameter.py +++ b/src/brevitas/export/torch/qoperator/handler/parameter.py @@ -10,6 +10,7 @@ from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d +from brevitas.nn import QuantConv3d from brevitas.nn import QuantLinear from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL @@ -93,7 +94,7 @@ def explicit_output_dtype(cls): return True @classmethod - def prepare_qf_kwargs(cls, module: Union[QuantConv1d, QuantConv2d]): + def prepare_qf_kwargs(cls, module: Union[QuantConv1d, QuantConv2d, QuantConv3d]): return { 'bias': cls.prepare_bias(module), 'stride': module.stride, @@ -119,6 +120,14 @@ def prepare_qf(cls, module: QuantConv2d): return torch.nn.quantized.functional.conv2d, cls.prepare_qf_kwargs(module) +class 
PytorchQuantConv3dHandler(PytorchQuantConvNdHandler): + handled_layer = QuantConv3d + + @classmethod + def prepare_qf(cls, module: QuantConv3d): + return torch.nn.quantized.functional.conv3d, cls.prepare_qf_kwargs(module) + + class PytorchQuantLinearHandler(PytorchQuantWBIOLHandler): handled_layer = QuantLinear diff --git a/src/brevitas/export/torch/qoperator/manager.py b/src/brevitas/export/torch/qoperator/manager.py index aea97957c..7567e6770 100644 --- a/src/brevitas/export/torch/qoperator/manager.py +++ b/src/brevitas/export/torch/qoperator/manager.py @@ -19,6 +19,7 @@ from .handler.act import PytorchQuantReLUHandler from .handler.parameter import PytorchQuantConv1dHandler from .handler.parameter import PytorchQuantConv2dHandler +from .handler.parameter import PytorchQuantConv3dHandler from .handler.parameter import PytorchQuantLinearHandler from .handler.pool import PytorchQuantMaxPool1d from .handler.pool import PytorchQuantMaxPool2d @@ -35,6 +36,7 @@ class TorchQOpManager(BaseManager): PytorchQuantReLUHandler, PytorchQuantConv1dHandler, PytorchQuantConv2dHandler, + PytorchQuantConv3dHandler, PytorchQuantLinearHandler] @classmethod diff --git a/src/brevitas/graph/fixed_point.py b/src/brevitas/graph/fixed_point.py index afbd32e67..fead31ba2 100644 --- a/src/brevitas/graph/fixed_point.py +++ b/src/brevitas/graph/fixed_point.py @@ -33,11 +33,14 @@ class MoveSplitBatchNormBeforeCat(UntilFixedPointGraphTransform): nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, + nn.ConvTranspose3d, qnn.QuantLinear, qnn.QuantConv1d, qnn.QuantConv2d, + qnn.QuantConv3d, qnn.QuantConvTranspose1d, - qnn.QuantConvTranspose2d) + qnn.QuantConvTranspose2d, + qnn.QuantConvTranspose3d) def __init__(self, before_modules_types=DEFAULT_BEFORE_MODULES_TYPES): super(MoveSplitBatchNormBeforeCat, self).__init__() @@ -93,8 +96,10 @@ class MergeBatchNorm(UntilFixedPointGraphTransform): nn.BatchNorm1d), (qnn.BatchNorm2dToQuantScaleBias, nn.BatchNorm2d), (qnn.QuantLinear, nn.BatchNorm1d), (qnn.QuantConv1d, nn.BatchNorm1d), (qnn.QuantConv2d, nn.BatchNorm2d), - (qnn.QuantConvTranspose1d, - nn.BatchNorm1d), (qnn.QuantConvTranspose2d, nn.BatchNorm2d)) + (qnn.QuantConv3d, + nn.BatchNorm3d), (qnn.QuantConvTranspose1d, nn.BatchNorm1d), + (qnn.QuantConvTranspose2d, + nn.BatchNorm2d), (qnn.QuantConvTranspose3d, nn.BatchNorm3d)) def __init__(self, patterns=DEFAULT_PATTERNS): super(MergeBatchNorm, self).__init__() diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index cad5d9043..19efbf9e8 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -156,7 +156,9 @@ def update_batch(self, module, input, current_layer): if isinstance(self.layer, SUPPORTED_CONV_OP): # Pick the correct unfoldNd class - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + if isinstance( + self.layer, + (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): unfold_impl = unfoldNd.UnfoldTransposeNd else: unfold_impl = unfoldNd.UnfoldNd @@ -220,7 +222,9 @@ def single_layer_update(self): dev = weight.device dtype = weight.dtype if isinstance(self.layer, SUPPORTED_CONV_OP): - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + if isinstance( + self.layer, + (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): weight = weight.transpose(1, 0) # This performs a view weight = weight.flatten(1) weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] diff --git a/src/brevitas/graph/gptq.py 
b/src/brevitas/graph/gptq.py index 56171ac6f..f1261f34b 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -148,7 +148,9 @@ def update_batch(self, module, input, current_layer): if isinstance(self.layer, SUPPORTED_CONV_OP): # Pick the correct unfoldNd class - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + if isinstance( + self.layer, + (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): unfold_impl = unfoldNd.UnfoldTransposeNd else: unfold_impl = unfoldNd.UnfoldNd @@ -193,7 +195,9 @@ def single_layer_update(self, percdamp=.01): dtype = weight.dtype if isinstance(self.layer, SUPPORTED_CONV_OP): - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + if isinstance( + self.layer, + (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): weight = weight.transpose(1, 0) # This performs a view weight = weight.flatten(1) diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index ddeef1c53..870bf6d69 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -18,7 +18,12 @@ from brevitas.quant_tensor import QuantTensor SUPPORTED_CONV_OP = ( - qnn.QuantConv2d, qnn.QuantConv1d, qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d) + qnn.QuantConv3d, + qnn.QuantConv2d, + qnn.QuantConv1d, + qnn.QuantConvTranspose1d, + qnn.QuantConvTranspose2d, + qnn.QuantConvTranspose3d) class StopFwdException(Exception): @@ -154,7 +159,9 @@ def __init__( # By default, use groups = 1 self.groups = 1 if isinstance(self.layer, SUPPORTED_CONV_OP): - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + if isinstance( + self.layer, + (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): weight = weight.transpose(1, 0) # This performs a view weight = weight.flatten(1) self.groups = self.layer.groups diff --git a/src/brevitas/graph/per_input.py b/src/brevitas/graph/per_input.py index 37f5f6eac..908c73a2d 100644 --- a/src/brevitas/graph/per_input.py +++ b/src/brevitas/graph/per_input.py @@ -10,6 +10,7 @@ from brevitas.graph.utils import replace_module from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d +from brevitas.nn import QuantConv3d from .base import PerInputModuleToModuleByHook @@ -93,9 +94,10 @@ def replace_modules(self, model): dw_conv = QuantConv1d(**kwargs) elif isinstance(avgpool, nn.AvgPool2d): dw_conv = QuantConv2d(**kwargs) + elif isinstance(avgpool, nn.AvgPool3d): + dw_conv = QuantConv3d(**kwargs) else: - assert isinstance(avgpool, nn.AvgPool3d) - raise RuntimeError("QuantConv3d not supported yet.") + raise RuntimeError("Unsupported operation.") kernel_value = 1. 
/ reduce(mul, dw_conv.kernel_size) dw_conv.register_parameter( 'scalar_weight', torch.nn.Parameter(torch.tensor(kernel_value))) diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py index 63143c4e5..1817c04ed 100644 --- a/src/brevitas/graph/quantize.py +++ b/src/brevitas/graph/quantize.py @@ -81,6 +81,12 @@ 'weight_quant': Int8WeightPerTensorFloat, 'bias_quant': Int32Bias, 'return_quant_tensor': True}), + nn.Conv3d: ( + qnn.QuantConv3d, + { + 'weight_quant': Int8WeightPerTensorFloat, + 'bias_quant': Int32Bias, + 'return_quant_tensor': True}), nn.ConvTranspose1d: ( qnn.QuantConvTranspose1d, { @@ -93,6 +99,12 @@ 'weight_quant': Int8WeightPerTensorFloat, 'bias_quant': Int32Bias, 'return_quant_tensor': True}), + nn.ConvTranspose3d: ( + qnn.QuantConvTranspose3d, + { + 'weight_quant': Int8WeightPerTensorFloat, + 'bias_quant': Int32Bias, + 'return_quant_tensor': True}), nn.Linear: ( qnn.QuantLinear, { @@ -150,6 +162,13 @@ 'weight_quant': Int8WeightPerTensorFloat, 'bias_quant': Int32Bias, 'return_quant_tensor': False}), + nn.Conv3d: ( + qnn.QuantConv3d, + { + 'input_quant': Int8ActPerTensorFloat, + 'weight_quant': Int8WeightPerTensorFloat, + 'bias_quant': Int32Bias, + 'return_quant_tensor': False}), nn.ConvTranspose1d: ( qnn.QuantConvTranspose1d, { @@ -164,6 +183,13 @@ 'weight_quant': Int8WeightPerTensorFloat, 'bias_quant': Int32Bias, 'return_quant_tensor': False}), + nn.ConvTranspose3d: ( + qnn.QuantConvTranspose3d, + { + 'input_quant': Int8ActPerTensorFloat, + 'weight_quant': Int8WeightPerTensorFloat, + 'bias_quant': Int32Bias, + 'return_quant_tensor': False}), nn.Linear: ( qnn.QuantLinear, { diff --git a/src/brevitas/graph/target/flexml.py b/src/brevitas/graph/target/flexml.py index 9aedd337c..35b9a4d14 100644 --- a/src/brevitas/graph/target/flexml.py +++ b/src/brevitas/graph/target/flexml.py @@ -45,6 +45,12 @@ 'weight_quant': Int8WeightPerTensorFixedPoint, 'bias_quant': Int16Bias, 'return_quant_tensor': True}), + nn.Conv3d: ( + qnn.QuantConv3d, + { + 'weight_quant': Int8WeightPerTensorFixedPoint, + 'bias_quant': Int16Bias, + 'return_quant_tensor': True}), nn.ConvTranspose1d: ( qnn.QuantConvTranspose1d, { @@ -57,6 +63,12 @@ 'weight_quant': Int8WeightPerTensorFixedPoint, 'bias_quant': Int16Bias, 'return_quant_tensor': True}), + nn.ConvTranspose3d: ( + qnn.QuantConvTranspose3d, + { + 'weight_quant': Int8WeightPerTensorFixedPoint, + 'bias_quant': Int16Bias, + 'return_quant_tensor': True}), nn.BatchNorm1d: ( qnn.BatchNorm1dToQuantScaleBias, { diff --git a/src/brevitas/graph/utils.py b/src/brevitas/graph/utils.py index 2ddfc9b2c..ed00de3eb 100644 --- a/src/brevitas/graph/utils.py +++ b/src/brevitas/graph/utils.py @@ -31,7 +31,8 @@ nn.ConvTranspose2d, nn.ConvTranspose3d, qnn.QuantConvTranspose1d, - qnn.QuantConvTranspose2d] + qnn.QuantConvTranspose2d, + qnn.QuantConvTranspose3d] def module_class_name(m: torch.nn.Module): diff --git a/tests/brevitas/export/quant_module_fixture.py b/tests/brevitas/export/quant_module_fixture.py index 8bfac6fd2..25db72984 100644 --- a/tests/brevitas/export/quant_module_fixture.py +++ b/tests/brevitas/export/quant_module_fixture.py @@ -9,8 +9,10 @@ from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d +from brevitas.nn import QuantConv3d from brevitas.nn import QuantConvTranspose1d from brevitas.nn import QuantConvTranspose2d +from brevitas.nn import QuantConvTranspose3d from brevitas.nn import QuantIdentity from brevitas.nn import QuantLinear from brevitas.nn import TruncAvgPool2d @@ -53,7 +55,13 @@ 'bias_external_scale': 
(Int32Bias,), 'bias_internal_scale': (Int8BiasPerTensorFloatInternalScaling,)} QUANT_WBIOL_IMPL = [ - QuantLinear, QuantConv1d, QuantConv2d, QuantConvTranspose1d, QuantConvTranspose2d] + QuantLinear, + QuantConv1d, + QuantConv2d, + QuantConv3d, + QuantConvTranspose1d, + QuantConvTranspose2d, + QuantConvTranspose3d] BIT_WIDTHS = [4, 8, 10] # below 8, equal 8, above 8 BIAS_BIT_WIDTHS = [8, 16, 32] diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py index 538e836e8..f099e3b91 100644 --- a/tests/brevitas/nn/nn_quantizers_fixture.py +++ b/tests/brevitas/nn/nn_quantizers_fixture.py @@ -12,8 +12,10 @@ from brevitas import torch_version from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d +from brevitas.nn import QuantConv3d from brevitas.nn import QuantConvTranspose1d from brevitas.nn import QuantConvTranspose2d +from brevitas.nn import QuantConvTranspose3d from brevitas.nn import QuantIdentity from brevitas.nn import QuantLinear from brevitas.nn.quant_mha import QuantMultiheadAttention @@ -91,8 +93,10 @@ QuantLinear, QuantConv1d, QuantConv2d, + QuantConv3d, QuantConvTranspose1d, - QuantConvTranspose2d,] + QuantConvTranspose2d, + QuantConvTranspose3d] ACC_BIT_WIDTHS = [8, 9, 10, 12, 16, 24, 32] diff --git a/tests/brevitas/nn/test_a2q.py b/tests/brevitas/nn/test_a2q.py index 8c34f390c..d34bd6683 100644 --- a/tests/brevitas/nn/test_a2q.py +++ b/tests/brevitas/nn/test_a2q.py @@ -102,12 +102,18 @@ def test_quant_wbiol_a2q(model_input, current_cases): elif kwargs[ 'model_type'] == 'QuantConv2d': # shape = (out_channels, in_channels, kernel_size, kernel_size) quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(1, 2, 3)) + elif kwargs[ + 'model_type'] == 'QuantConv3d': # shape = (out_channels, in_channels, kernel_size, kernel_size) + quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(1, 2, 3, 4)) elif kwargs[ 'model_type'] == 'QuantConvTranspose1d': # shape = (in_channels, out_channels, kernel_size) quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2)) elif kwargs[ 'model_type'] == 'QuantConvTranspose2d': # shape = (in_channels, out_channels, kernel_size) quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2, 3)) + elif kwargs[ + 'model_type'] == 'QuantConvTranspose3d': # shape = (in_channels, out_channels, kernel_size) + quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2, 3, 4)) else: raise NotImplementedError(f"Check for {kwargs['model_type']} is not yet implemented.") diff --git a/tests/brevitas/nn/test_wbiol.py b/tests/brevitas/nn/test_wbiol.py index 3577d25a1..fee8b4cfa 100644 --- a/tests/brevitas/nn/test_wbiol.py +++ b/tests/brevitas/nn/test_wbiol.py @@ -7,8 +7,10 @@ from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d +from brevitas.nn import QuantConv3d from brevitas.nn import QuantConvTranspose1d from brevitas.nn import QuantConvTranspose2d +from brevitas.nn import QuantConvTranspose3d from brevitas.nn import QuantLinear from brevitas.nn import QuantScaleBias from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL @@ -25,7 +27,13 @@ IN_CH = 5 KERNEL_SIZE = 3 -QUANT_CONV_VARIANTS = [QuantConv1d, QuantConv2d, QuantConvTranspose1d, QuantConvTranspose2d] +QUANT_CONV_VARIANTS = [ + QuantConv1d, + QuantConv2d, + QuantConv3d, + QuantConvTranspose1d, + QuantConvTranspose2d, + QuantConvTranspose3d] @pytest_cases.fixture() diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index e97f21fca..415d79a6b 
100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -14,8 +14,10 @@ from brevitas.export import export_qonnx from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d +from brevitas.nn import QuantConv3d from brevitas.nn import QuantConvTranspose1d from brevitas.nn import QuantConvTranspose2d +from brevitas.nn import QuantConvTranspose3d from brevitas.nn import QuantIdentity from brevitas.nn import QuantLinear from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint @@ -70,7 +72,13 @@ class A2QWeightQuantizerForTests(Int8AccumulatorAwareWeightQuant): 'symmetric_per_channel_fixed_point': (Int8WeightPerChannelFixedPoint, Int8ActPerTensorFixedPoint)} QUANT_WBIOL_IMPL = [ - QuantLinear, QuantConv1d, QuantConv2d, QuantConvTranspose1d, QuantConvTranspose2d] + QuantLinear, + QuantConv1d, + QuantConv2d, + QuantConv3d, + QuantConvTranspose1d, + QuantConvTranspose2d, + QuantConvTranspose3d] def compute_ort(export_name, np_input): diff --git a/tests/brevitas_ort/test_quant_module.py b/tests/brevitas_ort/test_quant_module.py index 0b1277686..619bcc93c 100644 --- a/tests/brevitas_ort/test_quant_module.py +++ b/tests/brevitas_ort/test_quant_module.py @@ -30,7 +30,8 @@ def test_ort_wbiol(model, export_type, current_cases): o_bit_width = case_id.split('-')[-5] i_bit_width = case_id.split('-')[-3] - if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d') and export_type == 'qop': + if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d', + 'QuantConvTranspose3d') and export_type == 'qop': pytest.skip('Export of ConvTranspose is not supported for QOperation') if 'per_channel' in quantizer and 'asymmetric' in quantizer: pytest.skip('Per-channel zero-point is not well supported in ORT.') From d71762ed780fb99c3e5384c81e0b408d7bd22f44 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Tue, 30 Jan 2024 16:10:36 +0000 Subject: [PATCH 14/37] removing unused import --- tests/brevitas_ort/common.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index 415d79a6b..0f9335bc0 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -18,7 +18,6 @@ from brevitas.nn import QuantConvTranspose1d from brevitas.nn import QuantConvTranspose2d from brevitas.nn import QuantConvTranspose3d -from brevitas.nn import QuantIdentity from brevitas.nn import QuantLinear from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint From f6559cceccb4cc70c912817584878307eac799fd Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Wed, 31 Jan 2024 10:30:13 +0000 Subject: [PATCH 15/37] added condition for quantconv2d and made default case conv3d --- tests/brevitas_ort/test_quant_module.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/brevitas_ort/test_quant_module.py b/tests/brevitas_ort/test_quant_module.py index 619bcc93c..dcbb184a8 100644 --- a/tests/brevitas_ort/test_quant_module.py +++ b/tests/brevitas_ort/test_quant_module.py @@ -45,8 +45,10 @@ def test_ort_wbiol(model, export_type, current_cases): in_size = (1, IN_CH) elif impl in ('QuantConv1d', 'QuantConvTranspose1d'): in_size = (1, IN_CH, FEATURES) - else: + elif impl in ('QuantConv2d', 'QuantConvTranspose2d'): in_size = (1, IN_CH, FEATURES, FEATURES) + else: + in_size = (1, IN_CH, FEATURES, FEATURES, FEATURES) inp = gen_linspaced_data(reduce(mul, 
in_size), -1, 1).reshape(in_size) From bab60429a8b4329dd28a44665c044fdcec9de446 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Wed, 31 Jan 2024 11:53:22 +0000 Subject: [PATCH 16/37] disable QuantConvTranspose3d in tests --- tests/brevitas/export/quant_module_fixture.py | 4 ++-- tests/brevitas/nn/nn_quantizers_fixture.py | 4 ++-- tests/brevitas/nn/test_a2q.py | 6 +++--- tests/brevitas/nn/test_wbiol.py | 4 ++-- tests/brevitas_ort/common.py | 4 ++-- tests/brevitas_ort/test_quant_module.py | 11 ++++++----- 6 files changed, 17 insertions(+), 16 deletions(-) diff --git a/tests/brevitas/export/quant_module_fixture.py b/tests/brevitas/export/quant_module_fixture.py index 25db72984..52d75469b 100644 --- a/tests/brevitas/export/quant_module_fixture.py +++ b/tests/brevitas/export/quant_module_fixture.py @@ -60,8 +60,8 @@ QuantConv2d, QuantConv3d, QuantConvTranspose1d, - QuantConvTranspose2d, - QuantConvTranspose3d] + QuantConvTranspose2d, #QuantConvTranspose3d, +] BIT_WIDTHS = [4, 8, 10] # below 8, equal 8, above 8 BIAS_BIT_WIDTHS = [8, 16, 32] diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py index f099e3b91..ed65061b6 100644 --- a/tests/brevitas/nn/nn_quantizers_fixture.py +++ b/tests/brevitas/nn/nn_quantizers_fixture.py @@ -95,8 +95,8 @@ QuantConv2d, QuantConv3d, QuantConvTranspose1d, - QuantConvTranspose2d, - QuantConvTranspose3d] + QuantConvTranspose2d, #QuantConvTranspose3d, +] ACC_BIT_WIDTHS = [8, 9, 10, 12, 16, 24, 32] diff --git a/tests/brevitas/nn/test_a2q.py b/tests/brevitas/nn/test_a2q.py index d34bd6683..362d1e463 100644 --- a/tests/brevitas/nn/test_a2q.py +++ b/tests/brevitas/nn/test_a2q.py @@ -111,9 +111,9 @@ def test_quant_wbiol_a2q(model_input, current_cases): elif kwargs[ 'model_type'] == 'QuantConvTranspose2d': # shape = (in_channels, out_channels, kernel_size) quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2, 3)) - elif kwargs[ - 'model_type'] == 'QuantConvTranspose3d': # shape = (in_channels, out_channels, kernel_size) - quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2, 3, 4)) + #elif kwargs[ + # 'model_type'] == 'QuantConvTranspose3d': # shape = (in_channels, out_channels, kernel_size) + # quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2, 3, 4)) else: raise NotImplementedError(f"Check for {kwargs['model_type']} is not yet implemented.") diff --git a/tests/brevitas/nn/test_wbiol.py b/tests/brevitas/nn/test_wbiol.py index fee8b4cfa..d373484db 100644 --- a/tests/brevitas/nn/test_wbiol.py +++ b/tests/brevitas/nn/test_wbiol.py @@ -32,8 +32,8 @@ QuantConv2d, QuantConv3d, QuantConvTranspose1d, - QuantConvTranspose2d, - QuantConvTranspose3d] + QuantConvTranspose2d, #QuantConvTranspose3d, +] @pytest_cases.fixture() diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index 0f9335bc0..d6b705fa9 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -76,8 +76,8 @@ class A2QWeightQuantizerForTests(Int8AccumulatorAwareWeightQuant): QuantConv2d, QuantConv3d, QuantConvTranspose1d, - QuantConvTranspose2d, - QuantConvTranspose3d] + QuantConvTranspose2d, #QuantConvTranspose3d, +] def compute_ort(export_name, np_input): diff --git a/tests/brevitas_ort/test_quant_module.py b/tests/brevitas_ort/test_quant_module.py index dcbb184a8..fdd95c04e 100644 --- a/tests/brevitas_ort/test_quant_module.py +++ b/tests/brevitas_ort/test_quant_module.py @@ -30,8 +30,8 @@ def test_ort_wbiol(model, export_type, 
current_cases): o_bit_width = case_id.split('-')[-5] i_bit_width = case_id.split('-')[-3] - if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d', - 'QuantConvTranspose3d') and export_type == 'qop': + #if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d','QuantConvTranspose3d') and export_type == 'qop': + if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d') and export_type == 'qop': pytest.skip('Export of ConvTranspose is not supported for QOperation') if 'per_channel' in quantizer and 'asymmetric' in quantizer: pytest.skip('Per-channel zero-point is not well supported in ORT.') @@ -45,10 +45,11 @@ def test_ort_wbiol(model, export_type, current_cases): in_size = (1, IN_CH) elif impl in ('QuantConv1d', 'QuantConvTranspose1d'): in_size = (1, IN_CH, FEATURES) - elif impl in ('QuantConv2d', 'QuantConvTranspose2d'): - in_size = (1, IN_CH, FEATURES, FEATURES) else: - in_size = (1, IN_CH, FEATURES, FEATURES, FEATURES) + #elif impl in ('QuantConv2d', 'QuantConvTranspose2d'): + in_size = (1, IN_CH, FEATURES, FEATURES) + #else: + # in_size = (1, IN_CH, FEATURES, FEATURES, FEATURES) inp = gen_linspaced_data(reduce(mul, in_size), -1, 1).reshape(in_size) From 22fe5c22e228281ac5c268dd749683944b0f547a Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Wed, 31 Jan 2024 12:17:19 +0000 Subject: [PATCH 17/37] restored function necessary for QuantConv3d to be tested --- tests/brevitas_ort/test_quant_module.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/brevitas_ort/test_quant_module.py b/tests/brevitas_ort/test_quant_module.py index fdd95c04e..ced58dbde 100644 --- a/tests/brevitas_ort/test_quant_module.py +++ b/tests/brevitas_ort/test_quant_module.py @@ -45,11 +45,10 @@ def test_ort_wbiol(model, export_type, current_cases): in_size = (1, IN_CH) elif impl in ('QuantConv1d', 'QuantConvTranspose1d'): in_size = (1, IN_CH, FEATURES) - else: - #elif impl in ('QuantConv2d', 'QuantConvTranspose2d'): + elif impl in ('QuantConv2d', 'QuantConvTranspose2d'): in_size = (1, IN_CH, FEATURES, FEATURES) - #else: - # in_size = (1, IN_CH, FEATURES, FEATURES, FEATURES) + else: + in_size = (1, IN_CH, FEATURES, FEATURES, FEATURES) inp = gen_linspaced_data(reduce(mul, in_size), -1, 1).reshape(in_size) From ee01b84c378930ae655e83d580d037085f921a43 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Wed, 31 Jan 2024 15:28:56 +0000 Subject: [PATCH 18/37] restored QuantConvTranspose3d to tests --- tests/brevitas/export/quant_module_fixture.py | 3 ++- tests/brevitas/nn/nn_quantizers_fixture.py | 3 ++- tests/brevitas/nn/test_a2q.py | 6 +++--- tests/brevitas/nn/test_wbiol.py | 3 ++- tests/brevitas_ort/common.py | 3 ++- tests/brevitas_ort/test_quant_module.py | 3 +-- 6 files changed, 12 insertions(+), 9 deletions(-) diff --git a/tests/brevitas/export/quant_module_fixture.py b/tests/brevitas/export/quant_module_fixture.py index 52d75469b..fadeab0e0 100644 --- a/tests/brevitas/export/quant_module_fixture.py +++ b/tests/brevitas/export/quant_module_fixture.py @@ -60,7 +60,8 @@ QuantConv2d, QuantConv3d, QuantConvTranspose1d, - QuantConvTranspose2d, #QuantConvTranspose3d, + QuantConvTranspose2d, + QuantConvTranspose3d, ] BIT_WIDTHS = [4, 8, 10] # below 8, equal 8, above 8 BIAS_BIT_WIDTHS = [8, 16, 32] diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py index ed65061b6..7b5d1e16f 100644 --- a/tests/brevitas/nn/nn_quantizers_fixture.py +++ 
b/tests/brevitas/nn/nn_quantizers_fixture.py @@ -95,7 +95,8 @@ QuantConv2d, QuantConv3d, QuantConvTranspose1d, - QuantConvTranspose2d, #QuantConvTranspose3d, + QuantConvTranspose2d, + QuantConvTranspose3d, ] ACC_BIT_WIDTHS = [8, 9, 10, 12, 16, 24, 32] diff --git a/tests/brevitas/nn/test_a2q.py b/tests/brevitas/nn/test_a2q.py index 362d1e463..d34bd6683 100644 --- a/tests/brevitas/nn/test_a2q.py +++ b/tests/brevitas/nn/test_a2q.py @@ -111,9 +111,9 @@ def test_quant_wbiol_a2q(model_input, current_cases): elif kwargs[ 'model_type'] == 'QuantConvTranspose2d': # shape = (in_channels, out_channels, kernel_size) quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2, 3)) - #elif kwargs[ - # 'model_type'] == 'QuantConvTranspose3d': # shape = (in_channels, out_channels, kernel_size) - # quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2, 3, 4)) + elif kwargs[ + 'model_type'] == 'QuantConvTranspose3d': # shape = (in_channels, out_channels, kernel_size) + quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2, 3, 4)) else: raise NotImplementedError(f"Check for {kwargs['model_type']} is not yet implemented.") diff --git a/tests/brevitas/nn/test_wbiol.py b/tests/brevitas/nn/test_wbiol.py index d373484db..28e6fa5f9 100644 --- a/tests/brevitas/nn/test_wbiol.py +++ b/tests/brevitas/nn/test_wbiol.py @@ -32,7 +32,8 @@ QuantConv2d, QuantConv3d, QuantConvTranspose1d, - QuantConvTranspose2d, #QuantConvTranspose3d, + QuantConvTranspose2d, + QuantConvTranspose3d, ] diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index d6b705fa9..2663a0d5d 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -76,7 +76,8 @@ class A2QWeightQuantizerForTests(Int8AccumulatorAwareWeightQuant): QuantConv2d, QuantConv3d, QuantConvTranspose1d, - QuantConvTranspose2d, #QuantConvTranspose3d, + QuantConvTranspose2d, + QuantConvTranspose3d, ] diff --git a/tests/brevitas_ort/test_quant_module.py b/tests/brevitas_ort/test_quant_module.py index ced58dbde..ab614b896 100644 --- a/tests/brevitas_ort/test_quant_module.py +++ b/tests/brevitas_ort/test_quant_module.py @@ -30,8 +30,7 @@ def test_ort_wbiol(model, export_type, current_cases): o_bit_width = case_id.split('-')[-5] i_bit_width = case_id.split('-')[-3] - #if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d','QuantConvTranspose3d') and export_type == 'qop': - if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d') and export_type == 'qop': + if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d', 'QuantConvTranspose3d') and export_type == 'qop': pytest.skip('Export of ConvTranspose is not supported for QOperation') if 'per_channel' in quantizer and 'asymmetric' in quantizer: pytest.skip('Per-channel zero-point is not well supported in ORT.') From 04ca06c1020b56f613ed83d428cb539c9bcf42ae Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Wed, 31 Jan 2024 15:34:23 +0000 Subject: [PATCH 19/37] pre-commit changes --- tests/brevitas/export/quant_module_fixture.py | 3 +-- tests/brevitas/nn/nn_quantizers_fixture.py | 3 +-- tests/brevitas/nn/test_wbiol.py | 3 +-- tests/brevitas_ort/common.py | 3 +-- tests/brevitas_ort/test_quant_module.py | 3 ++- 5 files changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/brevitas/export/quant_module_fixture.py b/tests/brevitas/export/quant_module_fixture.py index fadeab0e0..31524729f 100644 --- a/tests/brevitas/export/quant_module_fixture.py +++ b/tests/brevitas/export/quant_module_fixture.py @@ -61,8 
+61,7 @@ QuantConv3d, QuantConvTranspose1d, QuantConvTranspose2d, - QuantConvTranspose3d, -] + QuantConvTranspose3d,] BIT_WIDTHS = [4, 8, 10] # below 8, equal 8, above 8 BIAS_BIT_WIDTHS = [8, 16, 32] diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py index 7b5d1e16f..4d4983ba1 100644 --- a/tests/brevitas/nn/nn_quantizers_fixture.py +++ b/tests/brevitas/nn/nn_quantizers_fixture.py @@ -96,8 +96,7 @@ QuantConv3d, QuantConvTranspose1d, QuantConvTranspose2d, - QuantConvTranspose3d, -] + QuantConvTranspose3d,] ACC_BIT_WIDTHS = [8, 9, 10, 12, 16, 24, 32] diff --git a/tests/brevitas/nn/test_wbiol.py b/tests/brevitas/nn/test_wbiol.py index 28e6fa5f9..58b9a86ca 100644 --- a/tests/brevitas/nn/test_wbiol.py +++ b/tests/brevitas/nn/test_wbiol.py @@ -33,8 +33,7 @@ QuantConv3d, QuantConvTranspose1d, QuantConvTranspose2d, - QuantConvTranspose3d, -] + QuantConvTranspose3d,] @pytest_cases.fixture() diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index 2663a0d5d..4c148e96b 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -77,8 +77,7 @@ class A2QWeightQuantizerForTests(Int8AccumulatorAwareWeightQuant): QuantConv3d, QuantConvTranspose1d, QuantConvTranspose2d, - QuantConvTranspose3d, -] + QuantConvTranspose3d,] def compute_ort(export_name, np_input): diff --git a/tests/brevitas_ort/test_quant_module.py b/tests/brevitas_ort/test_quant_module.py index ab614b896..dcbb184a8 100644 --- a/tests/brevitas_ort/test_quant_module.py +++ b/tests/brevitas_ort/test_quant_module.py @@ -30,7 +30,8 @@ def test_ort_wbiol(model, export_type, current_cases): o_bit_width = case_id.split('-')[-5] i_bit_width = case_id.split('-')[-3] - if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d', 'QuantConvTranspose3d') and export_type == 'qop': + if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d', + 'QuantConvTranspose3d') and export_type == 'qop': pytest.skip('Export of ConvTranspose is not supported for QOperation') if 'per_channel' in quantizer and 'asymmetric' in quantizer: pytest.skip('Per-channel zero-point is not well supported in ORT.') From 382c50b35724110159572eaa1b05761f720dae0b Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Thu, 1 Feb 2024 16:06:32 +0000 Subject: [PATCH 20/37] fixed missing parts in test code causing test failures --- tests/brevitas/export/quant_module_fixture.py | 3 ++- tests/brevitas/export/test_torch_qcdq.py | 4 +++- tests/brevitas/nn/nn_quantizers_fixture.py | 7 +++++-- tests/brevitas/nn/test_wbiol.py | 3 ++- tests/brevitas_ort/common.py | 3 ++- 5 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/brevitas/export/quant_module_fixture.py b/tests/brevitas/export/quant_module_fixture.py index 31524729f..9e61c555b 100644 --- a/tests/brevitas/export/quant_module_fixture.py +++ b/tests/brevitas/export/quant_module_fixture.py @@ -61,7 +61,8 @@ QuantConv3d, QuantConvTranspose1d, QuantConvTranspose2d, - QuantConvTranspose3d,] + QuantConvTranspose3d, + ] BIT_WIDTHS = [4, 8, 10] # below 8, equal 8, above 8 BIAS_BIT_WIDTHS = [8, 16, 32] diff --git a/tests/brevitas/export/test_torch_qcdq.py b/tests/brevitas/export/test_torch_qcdq.py index 4f737f8d7..6019bf417 100644 --- a/tests/brevitas/export/test_torch_qcdq.py +++ b/tests/brevitas/export/test_torch_qcdq.py @@ -36,8 +36,10 @@ def test_torch_qcdq_wbiol_export( in_size = (1, IN_CH) elif quant_module_impl == QuantConv1d or quant_module_impl == QuantConvTranspose1d: in_size = (1, IN_CH, 
FEATURES) - else: + elif quant_module_impl == QuantConv2d or quant_module_impl == QuantConvTranspose2d: in_size = (1, IN_CH, FEATURES, FEATURES) + else: + in_size = (1, IN_CH, FEATURES, FEATURES, FEATURES) inp = torch.randn(in_size) quant_module(inp) # Collect scale factors diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py index 4d4983ba1..95c640de8 100644 --- a/tests/brevitas/nn/nn_quantizers_fixture.py +++ b/tests/brevitas/nn/nn_quantizers_fixture.py @@ -96,7 +96,8 @@ QuantConv3d, QuantConvTranspose1d, QuantConvTranspose2d, - QuantConvTranspose3d,] + QuantConvTranspose3d, + ] ACC_BIT_WIDTHS = [8, 9, 10, 12, 16, 24, 32] @@ -161,8 +162,10 @@ def forward(self, x): in_size = (1, IN_CH) elif impl in ('QuantConv1d', 'QuantConvTranspose1d'): in_size = (1, IN_CH, FEATURES) - else: + elif impl in ('QuantConv2d', 'QuantConvTranspose2d'): in_size = (1, IN_CH, FEATURES, FEATURES) + else: + in_size = (1, IN_CH, FEATURES, FEATURES, FEATURES) if input_quantized: quant_inp = QuantTensor( diff --git a/tests/brevitas/nn/test_wbiol.py b/tests/brevitas/nn/test_wbiol.py index 58b9a86ca..9df9faa4d 100644 --- a/tests/brevitas/nn/test_wbiol.py +++ b/tests/brevitas/nn/test_wbiol.py @@ -33,7 +33,8 @@ QuantConv3d, QuantConvTranspose1d, QuantConvTranspose2d, - QuantConvTranspose3d,] + QuantConvTranspose3d, + ] @pytest_cases.fixture() diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index 4c148e96b..ef46cfcbf 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -77,7 +77,8 @@ class A2QWeightQuantizerForTests(Int8AccumulatorAwareWeightQuant): QuantConv3d, QuantConvTranspose1d, QuantConvTranspose2d, - QuantConvTranspose3d,] + QuantConvTranspose3d, + ] def compute_ort(export_name, np_input): From a068004ada8802212798b6ae9a2dd7df71c7b0e6 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Thu, 1 Feb 2024 16:13:16 +0000 Subject: [PATCH 21/37] pre-commit hook changes --- tests/brevitas/export/quant_module_fixture.py | 3 +-- tests/brevitas/nn/nn_quantizers_fixture.py | 3 +-- tests/brevitas/nn/test_wbiol.py | 3 +-- tests/brevitas_ort/common.py | 3 +-- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/brevitas/export/quant_module_fixture.py b/tests/brevitas/export/quant_module_fixture.py index 9e61c555b..31524729f 100644 --- a/tests/brevitas/export/quant_module_fixture.py +++ b/tests/brevitas/export/quant_module_fixture.py @@ -61,8 +61,7 @@ QuantConv3d, QuantConvTranspose1d, QuantConvTranspose2d, - QuantConvTranspose3d, - ] + QuantConvTranspose3d,] BIT_WIDTHS = [4, 8, 10] # below 8, equal 8, above 8 BIAS_BIT_WIDTHS = [8, 16, 32] diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py index 95c640de8..910443dff 100644 --- a/tests/brevitas/nn/nn_quantizers_fixture.py +++ b/tests/brevitas/nn/nn_quantizers_fixture.py @@ -96,8 +96,7 @@ QuantConv3d, QuantConvTranspose1d, QuantConvTranspose2d, - QuantConvTranspose3d, - ] + QuantConvTranspose3d,] ACC_BIT_WIDTHS = [8, 9, 10, 12, 16, 24, 32] diff --git a/tests/brevitas/nn/test_wbiol.py b/tests/brevitas/nn/test_wbiol.py index 9df9faa4d..58b9a86ca 100644 --- a/tests/brevitas/nn/test_wbiol.py +++ b/tests/brevitas/nn/test_wbiol.py @@ -33,8 +33,7 @@ QuantConv3d, QuantConvTranspose1d, QuantConvTranspose2d, - QuantConvTranspose3d, - ] + QuantConvTranspose3d,] @pytest_cases.fixture() diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index ef46cfcbf..4c148e96b 
100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -77,8 +77,7 @@ class A2QWeightQuantizerForTests(Int8AccumulatorAwareWeightQuant): QuantConv3d, QuantConvTranspose1d, QuantConvTranspose2d, - QuantConvTranspose3d, - ] + QuantConvTranspose3d,] def compute_ort(export_name, np_input): From f30e36fa6cc991eaed55045fc20bfe7788aa9de6 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Wed, 7 Feb 2024 12:04:53 +0000 Subject: [PATCH 22/37] fixed typo - should be conv3d --- src/brevitas/nn/quant_conv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py index d5148e56f..c05200177 100644 --- a/src/brevitas/nn/quant_conv.py +++ b/src/brevitas/nn/quant_conv.py @@ -301,7 +301,7 @@ def conv3d_same_zeros_pad_stride(self, x, weight, bias): pad_h - pad_h // 2, pad_d // 2, pad_d - pad_d // 2]) - out = F.conv2d(x, weight, bias, self.stride, 0, self.dilation, self.groups) + out = F.conv3d(x, weight, bias, self.stride, 0, self.dilation, self.groups) return out def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: From f35cf4b1d60e68cd414e695650bdee6bd7e0dc63 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Thu, 8 Feb 2024 14:34:51 +0000 Subject: [PATCH 23/37] removed quantconv3d from flexml.py as it is unnecessary --- src/brevitas/graph/target/flexml.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/brevitas/graph/target/flexml.py b/src/brevitas/graph/target/flexml.py index 35b9a4d14..9aedd337c 100644 --- a/src/brevitas/graph/target/flexml.py +++ b/src/brevitas/graph/target/flexml.py @@ -45,12 +45,6 @@ 'weight_quant': Int8WeightPerTensorFixedPoint, 'bias_quant': Int16Bias, 'return_quant_tensor': True}), - nn.Conv3d: ( - qnn.QuantConv3d, - { - 'weight_quant': Int8WeightPerTensorFixedPoint, - 'bias_quant': Int16Bias, - 'return_quant_tensor': True}), nn.ConvTranspose1d: ( qnn.QuantConvTranspose1d, { @@ -63,12 +57,6 @@ 'weight_quant': Int8WeightPerTensorFixedPoint, 'bias_quant': Int16Bias, 'return_quant_tensor': True}), - nn.ConvTranspose3d: ( - qnn.QuantConvTranspose3d, - { - 'weight_quant': Int8WeightPerTensorFixedPoint, - 'bias_quant': Int16Bias, - 'return_quant_tensor': True}), nn.BatchNorm1d: ( qnn.BatchNorm1dToQuantScaleBias, { From 8e7b647b6a3273d1da27cf8d5c4fb6c9b3e5428d Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Thu, 8 Feb 2024 14:44:37 +0000 Subject: [PATCH 24/37] made check for 3d version explicit instead of default case --- tests/brevitas/nn/nn_quantizers_fixture.py | 4 +++- tests/brevitas_ort/test_quant_module.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py index 910443dff..6f0ddd6e0 100644 --- a/tests/brevitas/nn/nn_quantizers_fixture.py +++ b/tests/brevitas/nn/nn_quantizers_fixture.py @@ -163,8 +163,10 @@ def forward(self, x): in_size = (1, IN_CH, FEATURES) elif impl in ('QuantConv2d', 'QuantConvTranspose2d'): in_size = (1, IN_CH, FEATURES, FEATURES) - else: + elif impl in ('QuantConv3d', 'QuantConvTranspose3d'): in_size = (1, IN_CH, FEATURES, FEATURES, FEATURES) + else: + raise RuntimeError("Unsupported operation") if input_quantized: quant_inp = QuantTensor( diff --git a/tests/brevitas_ort/test_quant_module.py b/tests/brevitas_ort/test_quant_module.py index dcbb184a8..fce766d2d 100644 
--- a/tests/brevitas_ort/test_quant_module.py +++ b/tests/brevitas_ort/test_quant_module.py @@ -47,8 +47,10 @@ def test_ort_wbiol(model, export_type, current_cases): in_size = (1, IN_CH, FEATURES) elif impl in ('QuantConv2d', 'QuantConvTranspose2d'): in_size = (1, IN_CH, FEATURES, FEATURES) - else: + elif impl in ('QuantConv3d', 'QuantConvTranspose3d'): in_size = (1, IN_CH, FEATURES, FEATURES, FEATURES) + else: + raise RuntimeError("Unsupported operation") inp = gen_linspaced_data(reduce(mul, in_size), -1, 1).reshape(in_size) From 479eb9fd6c6e1585cd00ca5fa2ff1150187a1fcc Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Thu, 8 Feb 2024 14:54:59 +0000 Subject: [PATCH 25/37] added tests for conv1d,2d,3d merge batch norm --- tests/brevitas/graph/test_transforms.py | 60 +++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/tests/brevitas/graph/test_transforms.py b/tests/brevitas/graph/test_transforms.py index a4865ac24..fb875b96f 100644 --- a/tests/brevitas/graph/test_transforms.py +++ b/tests/brevitas/graph/test_transforms.py @@ -50,6 +50,66 @@ def test_rewriter_merge_bn(model_name: str, pretrained: bool): assert is_close +def test_conv1d_merge_bn(): + + class TestModel(nn.Module): + + def __init__(self): + super(TestModel, self).__init__() + + self.net = nn.Sequential(nn.Conv1d(16, 33, 3, stride=2), nn.BatchNorm1d(33), nn.ReLU()) + + def forward(self, x): + return self.net(x) + + model = TestModel() + graph = symbolic_trace(model) + graph = MergeBatchNorm().apply(graph) + + for m in graph.modules(): + assert not isinstance(m, nn.BatchNorm1d) + + +def test_conv2d_merge_bn(): + + class TestModel(nn.Module): + + def __init__(self): + super(TestModel, self).__init__() + + self.net = nn.Sequential(nn.Conv2d(16, 33, 3, stride=2), nn.BatchNorm2d(33), nn.ReLU()) + + def forward(self, x): + return self.net(x) + + model = TestModel() + graph = symbolic_trace(model) + graph = MergeBatchNorm().apply(graph) + + for m in graph.modules(): + assert not isinstance(m, nn.BatchNorm2d) + + +def test_conv3d_merge_bn(): + + class TestModel(nn.Module): + + def __init__(self): + super(TestModel, self).__init__() + + self.net = nn.Sequential(nn.Conv3d(16, 33, 3, stride=2), nn.BatchNorm3d(33), nn.ReLU()) + + def forward(self, x): + return self.net(x) + + model = TestModel() + graph = symbolic_trace(model) + graph = MergeBatchNorm().apply(graph) + + for m in graph.modules(): + assert not isinstance(m, nn.BatchNorm3d) + + def test_rewriter_duplicate_shared_relu(): class TestModel(nn.Module): From 51771bc33d38cde7677f6c3b362cd3e181a83ed9 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Thu, 8 Feb 2024 16:24:34 +0000 Subject: [PATCH 26/37] added tests to check that avgpool is replaced with quantconvs and that mergebatchnorm correctly removes batchnorm --- tests/brevitas/graph/test_transforms.py | 88 +++++++++++++++++-------- 1 file changed, 59 insertions(+), 29 deletions(-) diff --git a/tests/brevitas/graph/test_transforms.py b/tests/brevitas/graph/test_transforms.py index fb875b96f..f54090280 100644 --- a/tests/brevitas/graph/test_transforms.py +++ b/tests/brevitas/graph/test_transforms.py @@ -10,12 +10,16 @@ from torchvision import models from brevitas.fx import symbolic_trace +from brevitas.graph import AvgPoolToQuantDepthwiseConv from brevitas.graph import DuplicateSharedStatelessModule from brevitas.graph import FnToModule from brevitas.graph import MeanMethodToAdaptiveAvgPool2d from brevitas.graph import
MergeBatchNorm from brevitas.graph import MethodToModule from brevitas.graph.base import ModuleToModuleByInstance +from brevitas.nn import QuantConv1d +from brevitas.nn import QuantConv2d +from brevitas.nn import QuantConv3d SEED = 123456 INPUT_SIZE = (1, 3, 224, 224) @@ -50,64 +54,90 @@ def test_rewriter_merge_bn(model_name: str, pretrained: bool): assert is_close -def test_conv1d_merge_bn(): +@pytest.mark.parametrize("dims", [1, 2, 3]) +def test_conv_merge_bn(dims): class TestModel(nn.Module): - def __init__(self): + def __init__(self, dims): super(TestModel, self).__init__() + layers = [] - self.net = nn.Sequential(nn.Conv1d(16, 33, 3, stride=2), nn.BatchNorm1d(33), nn.ReLU()) - - def forward(self, x): - return self.net(x) - - model = TestModel() - graph = symbolic_trace(model) - graph = MergeBatchNorm().apply(graph) - - for m in graph.modules(): - assert not isinstance(m, nn.BatchNorm1d) - + if dims == 1: + layers.append(nn.Conv1d(16, 33, 3, stride=2)) + layers.append(nn.BatchNorm1d(33)) + elif dims == 2: + layers.append(nn.Conv2d(16, 33, 3, stride=2)) + layers.append(nn.BatchNorm2d(33)) + else: + layers.append(nn.Conv3d(16, 33, 3, stride=2)) + layers.append(nn.BatchNorm3d(33)) -def test_conv2d_merge_bn(): + layers.append(nn.ReLU()) - class TestModel(nn.Module): - - def __init__(self): - super(TestModel, self).__init__() - - self.net = nn.Sequential(nn.Conv2d(16, 33, 3, stride=2), nn.BatchNorm2d(33), nn.ReLU()) + self.net = nn.Sequential(*layers) def forward(self, x): return self.net(x) - model = TestModel() + model = TestModel(dims) graph = symbolic_trace(model) graph = MergeBatchNorm().apply(graph) for m in graph.modules(): - assert not isinstance(m, nn.BatchNorm2d) + if dims == 1: + assert not isinstance(m, nn.BatchNorm1d) + elif dims == 2: + assert not isinstance(m, nn.BatchNorm2d) + else: + assert not isinstance(m, nn.BatchNorm3d) -def test_conv3d_merge_bn(): +@pytest.mark.parametrize("dims", [1, 2, 3]) +def test_avg_pool_to_quant_conv(dims): class TestModel(nn.Module): - def __init__(self): + def __init__(self, dims): super(TestModel, self).__init__() - self.net = nn.Sequential(nn.Conv3d(16, 33, 3, stride=2), nn.BatchNorm3d(33), nn.ReLU()) + if dims == 1: + self.net = nn.Sequential(nn.AvgPool1d(3, stride=2), nn.ReLU()) + elif dims == 2: + self.net = nn.Sequential(nn.AvgPool2d(3, stride=2), nn.ReLU()) + else: + self.net = nn.Sequential(nn.AvgPool3d(3, stride=2), nn.ReLU()) def forward(self, x): return self.net(x) - model = TestModel() + model = TestModel(dims) + + args = None + if dims == 1: + args = torch.randn(20, 16, 10) + elif dims == 2: + args = torch.randn(20, 16, 10, 50) + else: + args = torch.randn(20, 16, 10, 50, 100) + graph = symbolic_trace(model) - graph = MergeBatchNorm().apply(graph) + graph = AvgPoolToQuantDepthwiseConv().apply(graph, args) + has_quant_conv = False for m in graph.modules(): - assert not isinstance(m, nn.BatchNorm3d) + if isinstance(m, QuantConv1d): + has_quant_conv = True + if isinstance(m, QuantConv2d): + has_quant_conv = True + if isinstance(m, QuantConv3d): + has_quant_conv = True + + assert not isinstance(m, nn.AvgPool1d) + assert not isinstance(m, nn.AvgPool2d) + assert not isinstance(m, nn.AvgPool3d) + + assert has_quant_conv def test_rewriter_duplicate_shared_relu(): From 4828856f424701b7497a2a1d9a116faebee29ba1 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Mon, 12 Feb 2024 09:52:31 +0000 Subject: [PATCH 27/37] reordered items in SUPPORTED_CONV_OP --- src/brevitas/graph/gpxq.py | 4 ++-- 1 
file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index 870bf6d69..4b4485652 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -18,9 +18,9 @@ from brevitas.quant_tensor import QuantTensor SUPPORTED_CONV_OP = ( - qnn.QuantConv3d, - qnn.QuantConv2d, qnn.QuantConv1d, + qnn.QuantConv2d, + qnn.QuantConv3d, qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d) From 82ea0daca726ee18577c2660df19a7e29f448ccc Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Mon, 12 Feb 2024 10:19:17 +0000 Subject: [PATCH 28/37] collapsed multiple checks into a single isinstance --- tests/brevitas/graph/test_transforms.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/brevitas/graph/test_transforms.py b/tests/brevitas/graph/test_transforms.py index f54090280..c58d9d828 100644 --- a/tests/brevitas/graph/test_transforms.py +++ b/tests/brevitas/graph/test_transforms.py @@ -126,16 +126,10 @@ def forward(self, x): has_quant_conv = False for m in graph.modules(): - if isinstance(m, QuantConv1d): - has_quant_conv = True - if isinstance(m, QuantConv2d): - has_quant_conv = True - if isinstance(m, QuantConv3d): + if isinstance(m, (QuantConv1d, QuantConv2d, QuantConv3d)): has_quant_conv = True - assert not isinstance(m, nn.AvgPool1d) - assert not isinstance(m, nn.AvgPool2d) - assert not isinstance(m, nn.AvgPool3d) + assert not isinstance(m, (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d)) assert has_quant_conv From ea729c8760c9e4112d7990d3fd761698db492328 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Thu, 22 Feb 2024 14:24:06 +0000 Subject: [PATCH 29/37] Revert "removed quantconv3d from flexml.py as it is unnecessary" This reverts commit f35cf4b1d60e68cd414e695650bdee6bd7e0dc63.
--- src/brevitas/graph/target/flexml.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/brevitas/graph/target/flexml.py b/src/brevitas/graph/target/flexml.py index 9aedd337c..35b9a4d14 100644 --- a/src/brevitas/graph/target/flexml.py +++ b/src/brevitas/graph/target/flexml.py @@ -45,6 +45,12 @@ 'weight_quant': Int8WeightPerTensorFixedPoint, 'bias_quant': Int16Bias, 'return_quant_tensor': True}), + nn.Conv3d: ( + qnn.QuantConv3d, + { + 'weight_quant': Int8WeightPerTensorFixedPoint, + 'bias_quant': Int16Bias, + 'return_quant_tensor': True}), nn.ConvTranspose1d: ( qnn.QuantConvTranspose1d, { @@ -57,6 +63,12 @@ 'weight_quant': Int8WeightPerTensorFixedPoint, 'bias_quant': Int16Bias, 'return_quant_tensor': True}), + nn.ConvTranspose3d: ( + qnn.QuantConvTranspose3d, + { + 'weight_quant': Int8WeightPerTensorFixedPoint, + 'bias_quant': Int16Bias, + 'return_quant_tensor': True}), nn.BatchNorm1d: ( qnn.BatchNorm1dToQuantScaleBias, { From a8ad695489da9930b9930b3992b626ff3ce3dd33 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Tue, 27 Feb 2024 15:34:26 +0000 Subject: [PATCH 30/37] correct incorrect value in Kernel3dApplHandlerMixin from 4 to 3 --- src/brevitas/export/onnx/handler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/brevitas/export/onnx/handler.py b/src/brevitas/export/onnx/handler.py index 9f657f10c..469c49907 100644 --- a/src/brevitas/export/onnx/handler.py +++ b/src/brevitas/export/onnx/handler.py @@ -101,21 +101,21 @@ def padding(module): @staticmethod def stride(module): if isinstance(module.stride, int): - return [module.stride] * 4 + return [module.stride] * 3 else: return list(module.stride) @staticmethod def dilation(module): if isinstance(module.dilation, int): - return [module.dilation] * 4 + return [module.dilation] * 3 else: return list(module.dilation) @staticmethod def kernel_shape(module): if isinstance(module.kernel_size, int): - return [module.kernel_size] * 4 + return [module.kernel_size] * 3 else: return list(module.kernel_size) From 4c76e4f21e41ea475b4ce9b97ca0b514ecdc6bd3 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Wed, 6 Mar 2024 08:42:07 +0000 Subject: [PATCH 31/37] updated comments for tensor shapes --- tests/brevitas/nn/test_a2q.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/brevitas/nn/test_a2q.py b/tests/brevitas/nn/test_a2q.py index d34bd6683..fa8dd701a 100644 --- a/tests/brevitas/nn/test_a2q.py +++ b/tests/brevitas/nn/test_a2q.py @@ -100,19 +100,19 @@ def test_quant_wbiol_a2q(model_input, current_cases): elif kwargs['model_type'] == 'QuantConv1d': # shape = (out_channels, in_channels, kernel_size) quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(1, 2)) elif kwargs[ - 'model_type'] == 'QuantConv2d': # shape = (out_channels, in_channels, kernel_size, kernel_size) + 'model_type'] == 'QuantConv2d': # shape = (out_channels, in_channels, kernel_height, kernel_width) quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(1, 2, 3)) elif kwargs[ - 'model_type'] == 'QuantConv3d': # shape = (out_channels, in_channels, kernel_size, kernel_size) + 'model_type'] == 'QuantConv3d': # shape = (out_channels, in_channels, kernel_depth, kernel_height, kernel_width) quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(1, 2, 3, 4)) elif kwargs[ 'model_type'] == 'QuantConvTranspose1d': # shape = (in_channels, out_channels, kernel_size) 
quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2)) elif kwargs[ - 'model_type'] == 'QuantConvTranspose2d': # shape = (in_channels, out_channels, kernel_size) + 'model_type'] == 'QuantConvTranspose2d': # shape = (in_channels, out_channels, kernel_height, kernel_width) quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2, 3)) elif kwargs[ - 'model_type'] == 'QuantConvTranspose3d': # shape = (in_channels, out_channels, kernel_size) + 'model_type'] == 'QuantConvTranspose3d': # shape = (in_channels, out_channels, kernel_depth, kernel_height, kernel_width) quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(0, 2, 3, 4)) else: raise NotImplementedError(f"Check for {kwargs['model_type']} is not yet implemented.") From 5b94089baa50fc6ee690bb7722c0b732630ca218 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Wed, 6 Mar 2024 09:38:44 +0000 Subject: [PATCH 32/37] updated convtranspose method based on PR suggestion --- src/brevitas/nn/quant_convtranspose.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/brevitas/nn/quant_convtranspose.py b/src/brevitas/nn/quant_convtranspose.py index 0a031b517..ab7e1cafe 100644 --- a/src/brevitas/nn/quant_convtranspose.py +++ b/src/brevitas/nn/quant_convtranspose.py @@ -118,8 +118,8 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) group_size = self.out_channels // self.groups - overlapping_sums = max(round(self.kernel_size[0] / self.stride[0]), 1) - max_uint_output = max_uint_input * max_kernel_val * overlapping_sums * group_size + patch_size = (self.kernel_size[0] // self.stride[0]) + max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) return max_output_bit_width @@ -215,9 +215,9 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) group_size = self.out_channels // self.groups - overlapping_sums = max(round(self.kernel_size[0] / self.stride[0]), 1) - overlapping_sums *= max(round(self.kernel_size[1] / self.stride[1]), 1) - max_uint_output = max_uint_input * max_kernel_val * overlapping_sums * group_size + patch_size = (self.kernel_size[0] // + self.stride[0]) * (self.kernel_size[1] // self.stride[1]) + max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) return max_output_bit_width @@ -313,9 +313,8 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) group_size = self.out_channels // self.groups - overlapping_sums = max(round(self.kernel_size[0] / self.stride[0]), 1) - overlapping_sums *= max(round(self.kernel_size[1] / self.stride[1]), 1) - overlapping_sums *= max(round(self.kernel_size[2] / self.stride[2]), 1) - max_uint_output = max_uint_input * max_kernel_val * overlapping_sums * group_size + patch_size = (self.kernel_size[0] // self.stride[0]) * ( + self.kernel_size[1] // self.stride[1]) * (self.kernel_size[2] // self.stride[2]) + max_uint_output = max_uint_input * 
max_kernel_val * patch_size * group_size max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) return max_output_bit_width From b607c8fe8114b7004af212abdc933e557540af1c Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Wed, 6 Mar 2024 14:29:22 +0000 Subject: [PATCH 33/37] added max(...,1) to patch_size calculation --- src/brevitas/nn/quant_convtranspose.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/brevitas/nn/quant_convtranspose.py b/src/brevitas/nn/quant_convtranspose.py index ab7e1cafe..e3393947a 100644 --- a/src/brevitas/nn/quant_convtranspose.py +++ b/src/brevitas/nn/quant_convtranspose.py @@ -118,7 +118,7 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) group_size = self.out_channels // self.groups - patch_size = (self.kernel_size[0] // self.stride[0]) + patch_size = max(self.kernel_size[0] // self.stride[0], 1) max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) return max_output_bit_width @@ -215,8 +215,8 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) group_size = self.out_channels // self.groups - patch_size = (self.kernel_size[0] // - self.stride[0]) * (self.kernel_size[1] // self.stride[1]) + patch_size = max(self.kernel_size[0] // self.stride[0], 1) * max( + self.kernel_size[1] // self.stride[1], 1) max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) return max_output_bit_width @@ -313,8 +313,9 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) group_size = self.out_channels // self.groups - patch_size = (self.kernel_size[0] // self.stride[0]) * ( - self.kernel_size[1] // self.stride[1]) * (self.kernel_size[2] // self.stride[2]) + patch_size = max(self.kernel_size[0] // self.stride[0], 1) * max( + self.kernel_size[1] // self.stride[1], 1) * max( + self.kernel_size[2] // self.stride[2], 1) max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) return max_output_bit_width From 6271f4025c04cca5f27664f5187fe1760a00653f Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Wed, 6 Mar 2024 15:28:36 +0000 Subject: [PATCH 34/37] added some basic tests for convtranspose --- tests/brevitas/nn/test_convtranspose.py | 69 +++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 tests/brevitas/nn/test_convtranspose.py diff --git a/tests/brevitas/nn/test_convtranspose.py b/tests/brevitas/nn/test_convtranspose.py new file mode 100644 index 000000000..9b442ad77 --- /dev/null +++ b/tests/brevitas/nn/test_convtranspose.py @@ -0,0 +1,69 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause + +import torch +import torch.nn as nn + +from brevitas.nn import QuantConvTranspose1d +from brevitas.nn import QuantConvTranspose2d +from brevitas.nn import QuantConvTranspose3d + + +def test_quantconvtranspose1d(): + in_channels = 16 + out_channels = 4 + kernel_size = 3 + + input = torch.ones(10, in_channels, 50) + + normal = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride=2) + normal_output = normal(input) + + quant = QuantConvTranspose1d(in_channels, out_channels, kernel_size, stride=2) + + # re-using weight and bias so the layers should give the same results + quant.weight = normal.weight + quant.bias = normal.bias + quant_output = quant(input) + + assert torch.isclose(normal_output, quant_output, atol=0.01).all().item() + + +def test_quantconvtranspose2d(): + in_channels = 16 + out_channels = 4 + kernel_size = 3 + + input = torch.ones(10, in_channels, 50, 100) + + normal = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=2) + normal_output = normal(input) + + quant = QuantConvTranspose2d(in_channels, out_channels, kernel_size, stride=2) + + # re-using weight and bias so the layers should give the same results + quant.weight = normal.weight + quant.bias = normal.bias + quant_output = quant(input) + + assert torch.isclose(normal_output, quant_output, atol=0.01).all().item() + + +def test_quantconvtranspose3d(): + in_channels = 16 + out_channels = 4 + kernel_size = 3 + + input = torch.ones(10, in_channels, 10, 50, 100) + + normal = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=2) + normal_output = normal(input) + + quant = QuantConvTranspose3d(in_channels, out_channels, kernel_size, stride=2) + + # re-using weight and bias so the layers should give the same results + quant.weight = normal.weight + quant.bias = normal.bias + quant_output = quant(input) + + assert torch.isclose(normal_output, quant_output, atol=0.01).all().item() From 6a651c359060ae94116f0090e4f813a7f63f4d11 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Wed, 6 Mar 2024 15:38:49 +0000 Subject: [PATCH 35/37] updated copyright year --- tests/brevitas/nn/test_convtranspose.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/brevitas/nn/test_convtranspose.py b/tests/brevitas/nn/test_convtranspose.py index 9b442ad77..deb585db0 100644 --- a/tests/brevitas/nn/test_convtranspose.py +++ b/tests/brevitas/nn/test_convtranspose.py @@ -1,4 +1,4 @@ -# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause import torch From f4f92e686963e0b7836a3871908cb4d26602c0d5 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Wed, 6 Mar 2024 16:39:39 +0000 Subject: [PATCH 36/37] changed rounding from floor equivalent to ceil --- src/brevitas/nn/quant_convtranspose.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/brevitas/nn/quant_convtranspose.py b/src/brevitas/nn/quant_convtranspose.py index e3393947a..7c81066dc 100644 --- a/src/brevitas/nn/quant_convtranspose.py +++ b/src/brevitas/nn/quant_convtranspose.py @@ -118,7 +118,7 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) group_size = self.out_channels // self.groups - patch_size = max(self.kernel_size[0] // self.stride[0], 1) + patch_size = max(torch.ceil(self.kernel_size[0] / self.stride[0]), 1) max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) return max_output_bit_width @@ -215,8 +215,8 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) group_size = self.out_channels // self.groups - patch_size = max(self.kernel_size[0] // self.stride[0], 1) * max( - self.kernel_size[1] // self.stride[1], 1) + patch_size = max(torch.ceil(self.kernel_size[0] / self.stride[0]), 1) * max( + torch.ceil(self.kernel_size[1] / self.stride[1]), 1) max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) return max_output_bit_width @@ -313,9 +313,9 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) group_size = self.out_channels // self.groups - patch_size = max(self.kernel_size[0] // self.stride[0], 1) * max( - self.kernel_size[1] // self.stride[1], 1) * max( - self.kernel_size[2] // self.stride[2], 1) + patch_size = max(torch.ceil(self.kernel_size[0] / self.stride[0]), 1) * max( + torch.ceil(self.kernel_size[1] / self.stride[1]), 1) * max( + torch.ceil(self.kernel_size[2] / self.stride[2]), 1) max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) return max_output_bit_width From 3e5db0af4eadd37c714ef490bfffd828697f2d65 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Wed, 6 Mar 2024 17:22:03 +0000 Subject: [PATCH 37/37] switched torch.ceil to math.ceil --- src/brevitas/nn/quant_convtranspose.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/brevitas/nn/quant_convtranspose.py b/src/brevitas/nn/quant_convtranspose.py index 7c81066dc..75dd90378 100644 --- a/src/brevitas/nn/quant_convtranspose.py +++ b/src/brevitas/nn/quant_convtranspose.py @@ -1,6 +1,7 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause +import math from typing import Optional, Tuple, Type, Union from packaging import version @@ -118,7 +119,7 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) group_size = self.out_channels // self.groups - patch_size = max(torch.ceil(self.kernel_size[0] / self.stride[0]), 1) + patch_size = max(math.ceil(self.kernel_size[0] / self.stride[0]), 1) max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) return max_output_bit_width @@ -215,8 +216,8 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) group_size = self.out_channels // self.groups - patch_size = max(torch.ceil(self.kernel_size[0] / self.stride[0]), 1) * max( - torch.ceil(self.kernel_size[1] / self.stride[1]), 1) + patch_size = max(math.ceil(self.kernel_size[0] / self.stride[0]), 1) * max( + math.ceil(self.kernel_size[1] / self.stride[1]), 1) max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) return max_output_bit_width @@ -313,9 +314,9 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) group_size = self.out_channels // self.groups - patch_size = max(torch.ceil(self.kernel_size[0] / self.stride[0]), 1) * max( - torch.ceil(self.kernel_size[1] / self.stride[1]), 1) * max( - torch.ceil(self.kernel_size[2] / self.stride[2]), 1) + patch_size = max(math.ceil(self.kernel_size[0] / self.stride[0]), 1) * max( + math.ceil(self.kernel_size[1] / self.stride[1]), 1) * max( + math.ceil(self.kernel_size[2] / self.stride[2]), 1) max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) return max_output_bit_width
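
A note on the rounding change in the last two patches: ceil(kernel_size / stride), not floor-style integer division, is the worst-case number of kernel taps that can overlap a single output element of a transposed convolution, and torch.ceil cannot be used here because kernel_size[i] / stride[i] is a plain Python float rather than a Tensor, hence the final switch to math.ceil. The following standalone sketch checks that bound; the sizes k, s and in_ch are arbitrary illustrative values, and the all-ones worst case makes each output element equal to the number of products accumulated into it:

import math

import torch
import torch.nn.functional as F

# All-ones input and weight: every product is 1, so each output element of the
# transposed conv equals the number of products summed into its accumulator.
k, s, in_ch = 3, 2, 4
x = torch.ones(1, in_ch, 10)
w = torch.ones(in_ch, 1, k)  # ConvTranspose1d weight layout: (in_ch, out_ch // groups, k)
y = F.conv_transpose1d(x, w, stride=s)

patch_size = max(math.ceil(k / s), 1)  # 2 overlapping taps for k=3, s=2
assert int(y.max().item()) == in_ch * patch_size  # 4 * 2 = 8 products at the peak
# Floor-style rounding (k // s) would give patch_size = 1 and bound the peak by 4,
# under-estimating the accumulator; that is the gap the ceil change closes.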
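
For context on the one-character fix in PATCH 22: PyTorch rejects padding='same' for strided convolutions, so conv3d_same_zeros_pad_stride pads each spatial dim by hand so that the output keeps ceil(input_dim / stride) elements, and the padded tensor must then go through F.conv3d rather than F.conv2d. A minimal standalone sketch of that shape contract, with arbitrary illustrative sizes and dilation assumed to be 1:

import math

import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 9, 10, 11)
k, s = 3, 2
weight = torch.randn(4, 2, k, k, k)

d, h, w = x.shape[-3:]
pads = []
for size in (w, h, d):  # F.pad takes (left, right) pairs, last dim first
    out = math.ceil(size / s)
    total = max((out - 1) * s + (k - 1) + 1 - size, 0)
    pads += [total // 2, total - total // 2]
x_padded = F.pad(x, pads)
y = F.conv3d(x_padded, weight, stride=s)
# Each spatial output dim is ceil(input_dim / stride), i.e. (5, 5, 6) here.
assert y.shape[-3:] == (math.ceil(d / s), math.ceil(h / s), math.ceil(w / s))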
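
The merge-BN tests added in PATCH 25 and reworked in PATCH 26 only assert that the BatchNorm module is gone after the rewrite; the arithmetic that makes the removal safe is the standard eval-mode folding of the BN affine transform into the preceding conv. A minimal sketch of that identity, separate from the Brevitas implementation itself and reusing the 16-to-33-channel, stride-2 shapes of the new tests:

import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(16, 33, 3, stride=2)
bn = nn.BatchNorm2d(33).eval()
# Populate non-trivial BN statistics and affine parameters.
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 2.0)
bn.bias.data.uniform_(-1, 1)

# Fold: w' = w * gamma / sqrt(var + eps); b' = (b - mean) * gamma / sqrt(var + eps) + beta
scale = bn.weight.data / torch.sqrt(bn.running_var + bn.eps)
fused = nn.Conv2d(16, 33, 3, stride=2)
fused.weight.data = conv.weight.data * scale.view(-1, 1, 1, 1)
fused.bias.data = (conv.bias.data - bn.running_mean) * scale + bn.bias.data

x = torch.randn(2, 16, 32, 32)
with torch.no_grad():
    assert torch.allclose(bn(conv(x)), fused(x), atol=1e-4)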
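
Likewise, the expectation in PATCH 26 that average pools turn into quant convs rests on a simple equivalence: an average pool is a depthwise convolution whose kernel is the constant 1 / prod(kernel_size). A standalone check of that equivalence (not the AvgPoolToQuantDepthwiseConv rewriter itself; the 2d case and the sizes below are illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(1, 16, 10, 50)
k, s = 3, 2
pooled = F.avg_pool2d(x, k, stride=s)

# Depthwise conv (groups == channels) with a constant 1/(k*k) kernel.
w = torch.full((16, 1, k, k), 1.0 / (k * k))
conv = F.conv2d(x, w, stride=s, groups=16)

assert torch.allclose(pooled, conv, atol=1e-5)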