From 7196dc675d96b54acbbfed1a305be59413569c7e Mon Sep 17 00:00:00 2001 From: Alessandro Pappalardo <1934033+volcacius@users.noreply.github.com> Date: Fri, 22 Sep 2023 14:44:22 +0200 Subject: [PATCH] Fix (nn): add missing support for padding_mode (#709) --- src/brevitas/nn/quant_conv.py | 55 ++++++++++++-------------- src/brevitas/nn/quant_convtranspose.py | 4 ++ 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py index fd6ab2e39..81aaf86f2 100644 --- a/src/brevitas/nn/quant_conv.py +++ b/src/brevitas/nn/quant_conv.py @@ -9,7 +9,6 @@ from torch.nn import Conv1d from torch.nn import Conv2d from torch.nn import functional as F -from torch.nn.functional import conv2d from brevitas.function.ops import max_int from brevitas.function.ops_ste import ceil_ste @@ -35,8 +34,8 @@ def __init__( padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, + padding_mode: str = 'zeros', bias: bool = True, - padding_type: str = 'standard', weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat, bias_quant: Optional[BiasQuantType] = None, input_quant: Optional[ActQuantType] = None, @@ -45,6 +44,12 @@ def __init__( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, **kwargs) -> None: + # avoid an init error in the super class by setting padding to 0 + if padding_mode == 'zeros' and padding == 'same' and stride > 1: + padding = 0 + is_same_padded_strided = True + else: + is_same_padded_strided = False Conv1d.__init__( self, in_channels=in_channels, @@ -54,6 +59,7 @@ def __init__( padding=padding, dilation=dilation, groups=groups, + padding_mode=padding_mode, bias=bias, device=device, dtype=dtype) @@ -65,9 +71,7 @@ def __init__( output_quant=output_quant, return_quant_tensor=return_quant_tensor, **kwargs) - assert self.padding_mode == 'zeros' - assert not (padding_type == 'same' and padding != 0) - self.padding_type = padding_type + self.is_same_padded_strided = is_same_padded_strided @property def per_elem_ops(self): @@ -84,11 +88,7 @@ def output_channel_dim(self): def channelwise_separable(self) -> bool: return self.groups == self.in_channels - def conv1d_zeros_pad(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]): - out = F.conv1d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups) - return out - - def conv1d_same_zeros_pad(self, x, weight, bias): + def conv1d_same_zeros_pad_stride(self, x, weight, bias): ih = x.size()[-1] kh = weight.size()[-1] sh = self.stride[0] @@ -103,12 +103,10 @@ def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTenso return self.forward_impl(input) def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]): - if self.padding_type == 'standard': - return self.conv1d_zeros_pad(x, quant_weight, quant_bias) - elif self.padding_type == 'same': - return self.conv1d_same_zeros_pad(x, quant_weight, quant_bias) + if self.is_same_padded_strided: + return self.conv1d_same_zeros_pad_stride(x, quant_weight, quant_bias) else: - raise NotImplementedError(f"Padding type {self.padding_type} not supported.") + return self._conv_forward(x, quant_weight, quant_bias) def max_acc_bit_width(self, input_bit_width, weight_bit_width): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) @@ -130,8 +128,8 @@ def __init__( padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, + padding_mode: str = 'zeros', bias: bool = True, - padding_type: str = 'standard', weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat, bias_quant: Optional[BiasQuantType] = None, input_quant: Optional[ActQuantType] = None, @@ -140,6 +138,12 @@ def __init__( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, **kwargs) -> None: + # avoid an init error in the super class by setting padding to 0 + if padding_mode == 'zeros' and padding == 'same' and stride > 1: + padding = 0 + is_same_padded_strided = True + else: + is_same_padded_strided = False Conv2d.__init__( self, in_channels=in_channels, @@ -147,6 +151,7 @@ def __init__( kernel_size=kernel_size, stride=stride, padding=padding, + padding_mode=padding_mode, dilation=dilation, groups=groups, bias=bias, @@ -160,9 +165,7 @@ def __init__( output_quant=output_quant, return_quant_tensor=return_quant_tensor, **kwargs) - assert self.padding_mode == 'zeros' - assert not (padding_type == 'same' and padding != 0) - self.padding_type = padding_type + self.is_same_padded_strided = is_same_padded_strided @property def per_elem_ops(self): @@ -179,11 +182,7 @@ def output_channel_dim(self): def channelwise_separable(self) -> bool: return self.groups == self.in_channels - def conv2d_zeros_pad(self, x: Tensor, weight: Tensor, bias: Tensor): - out = conv2d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups) - return out - - def conv2d_same_zeros_pad(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]): + def conv2d_same_zeros_pad_stride(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]): ih, iw = x.size()[-2:] kh, kw = weight.size()[-2:] sh, sw = self.stride @@ -199,12 +198,10 @@ def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTenso return self.forward_impl(input) def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]): - if self.padding_type == 'standard': - return self.conv2d_zeros_pad(x, quant_weight, quant_bias) - elif self.padding_type == 'same': - return self.conv2d_same_zeros_pad(x, quant_weight, quant_bias) + if self.is_same_padded_strided: + return self.conv2d_same_zeros_pad_stride(x, quant_weight, quant_bias) else: - raise RuntimeError(f"Padding type {self.padding_type} not supported.") + return self._conv_forward(x, quant_weight, quant_bias) def max_acc_bit_width(self, input_bit_width: Tensor, weight_bit_width: Tensor): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) diff --git a/src/brevitas/nn/quant_convtranspose.py b/src/brevitas/nn/quant_convtranspose.py index 7aa997f7a..c5dbd52b9 100644 --- a/src/brevitas/nn/quant_convtranspose.py +++ b/src/brevitas/nn/quant_convtranspose.py @@ -37,6 +37,7 @@ def __init__( output_padding: int = 0, dilation: int = 1, groups: int = 1, + padding_mode: str = 'zeros', bias: bool = True, weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat, bias_quant: Optional[BiasQuantType] = None, @@ -56,6 +57,7 @@ def __init__( output_padding=output_padding, dilation=dilation, groups=groups, + padding_mode=padding_mode, bias=bias, device=device, dtype=dtype) @@ -132,6 +134,7 @@ def __init__( output_padding: Union[int, Tuple[int]] = 0, dilation: Union[int, Tuple[int]] = 1, groups: int = 1, + padding_mode: str = 'zeros', bias: bool = True, weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat, bias_quant: Optional[BiasQuantType] = None, @@ -151,6 +154,7 @@ def __init__( output_padding=output_padding, dilation=dilation, groups=groups, + padding_mode=padding_mode, bias=bias, device=device, dtype=dtype)