From e45b0c3170e6f60043a2e36d7b3cfe4c7d91199d Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 22 Mar 2024 10:12:25 +0100 Subject: [PATCH 1/4] Fix (quant_tensor): fix typing and remove unused checks (#913) --- src/brevitas/quant_tensor/__init__.py | 31 +++++++++++---------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index c66690c50..bfbe81446 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -22,11 +22,11 @@ class QuantTensorBase(NamedTuple): value: Tensor - scale: Optional[Tensor] - zero_point: Optional[Tensor] - bit_width: Optional[Tensor] - signed_t: Optional[Tensor] - training_t: Optional[Tensor] + scale: Tensor + zero_point: Tensor + bit_width: Tensor + signed_t: Tensor + training_t: Tensor def _unpack_quant_tensor(input_data): @@ -61,17 +61,11 @@ def __new__(cls, value, scale, zero_point, bit_width, signed, training): @property def signed(self): - if self.signed_t is not None: - return self.signed_t.item() - else: - return None + return self.signed_t.item() @property def training(self): - if self.training_t is not None: - return self.training_t.item() - else: - return None + return self.training_t.item() def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: @@ -129,8 +123,7 @@ def device(self): value_device = self.value.device is_same_device = True for t in [self.scale, self.zero_point, self.bit_width]: - if t is not None: - is_same_device &= value_device == t.device + is_same_device &= value_device == t.device if not is_same_device: raise RuntimeError("Value and metadata are on different devices") return value_device @@ -193,13 +186,13 @@ def is_zero_zero_point(tensor): return (tensor.zero_point == 0.).all() def check_scaling_factors_same(self, other): - if self.training is not None and self.training: + if self.training: return True if not torch.allclose(self.scale, other.scale): raise RuntimeError("Scaling factors are different") def check_zero_points_same(self, other): - if self.training is not None and self.training: + if self.training: return True if not torch.allclose(self.zero_point, other.zero_point): raise RuntimeError("Zero points are different") @@ -226,7 +219,7 @@ def transpose(self, *args, **kwargs): tensor_meta = { 'scale': self.scale, 'zero_point': self.zero_point, 'bit_width': self.bit_width} for k, tm in tensor_meta.items(): - if tm is not None and len(value.shape) == len(tm.shape): + if len(value.shape) == len(tm.shape): tensor_meta[k] = tm.transpose(*args, **kwargs) return self.set(value=value, **tensor_meta) @@ -235,7 +228,7 @@ def permute(self, *args, **kwargs): tensor_meta = { 'scale': self.scale, 'zero_point': self.zero_point, 'bit_width': self.bit_width} for k, tm in tensor_meta.items(): - if tm is not None and len(value.shape) == len(tm.shape): + if len(value.shape) == len(tm.shape): tensor_meta[k] = tm.permute(*args, **kwargs) return self.set(value=value, **tensor_meta) From f0f8c560f99517f3474cb6789e3959b2dbfb9249 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 22 Mar 2024 14:57:48 +0100 Subject: [PATCH 2/4] Fix (nn): removed unused caching in quant adaptive avgpool2d (#911) --- src/brevitas/nn/quant_avg_pool.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/src/brevitas/nn/quant_avg_pool.py b/src/brevitas/nn/quant_avg_pool.py index 5d567d0ca..554504908 100644 --- a/src/brevitas/nn/quant_avg_pool.py +++ b/src/brevitas/nn/quant_avg_pool.py @@ -93,14 +93,10 @@ def __init__( output_size: Union[int, Tuple[int, int]], trunc_quant: Optional[AccQuantType] = RoundTo8bit, return_quant_tensor: bool = True, - cache_kernel_size_stride: bool = True, **kwargs): AdaptiveAvgPool2d.__init__(self, output_size=output_size) QuantLayerMixin.__init__(self, return_quant_tensor) TruncMixin.__init__(self, trunc_quant=trunc_quant, **kwargs) - self.cache_kernel_size_stride = cache_kernel_size_stride - self._cached_kernel_size = None - self._cached_kernel_stride = None @property def channelwise_separable(self) -> bool: @@ -114,15 +110,8 @@ def requires_export_handler(self): def padding(self): return 0 - @property - def kernel_size(self): - return self._cached_kernel_size - - @property - def stride(self): - return self._cached_kernel_stride - - def compute_kernel_size_stride(self, input_shape, output_shape): + @staticmethod + def compute_kernel_size_stride(input_shape, output_shape): kernel_size_list = [] stride_list = [] for inp, out in zip(input_shape, output_shape): @@ -141,10 +130,6 @@ def forward(self, input: Union[Tensor, QuantTensor]): self._set_global_is_quant_layer(False) return out - if self.cache_kernel_size_stride: - self._cached_kernel_size = k_size - self._cached_kernel_stride = stride - if isinstance(x, QuantTensor): y = x.set(value=super(TruncAdaptiveAvgPool2d, self).forward(x.value)) k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:]) From 3f7a36ff5dbe1916ff88ea509ca5238da3e0a4bf Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 23 Mar 2024 11:57:25 +0100 Subject: [PATCH 3/4] Fix (quant_tensor): remove unused checks (#918) --- src/brevitas/quant_tensor/__init__.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index bfbe81446..569ed71e0 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -352,8 +352,6 @@ def __add__(self, other): bit_width=output_bit_width, signed=output_signed, training=output_training) - elif isinstance(other, QuantTensor): - output = self.value + other.value else: output = self.value + other return output @@ -382,8 +380,6 @@ def __mul__(self, other): bit_width=output_bit_width, signed=output_signed, training=output_training) - elif isinstance(other, QuantTensor): - output = self.value * other.value else: output = self.value * other return output @@ -413,8 +409,6 @@ def __truediv__(self, other): bit_width=output_bit_width, signed=output_signed, training=output_training) - elif isinstance(other, QuantTensor): - output = self.value / other.value else: output = self.value / other return output From 44cb08bcde9b9522a62629bf0ba6839dc2d17ef0 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 27 Mar 2024 08:51:33 +0100 Subject: [PATCH 4/4] Setup: pin ONNX to 1.15 due to ORT incompatibility (#924) --- requirements/requirements-export.txt | 2 +- requirements/requirements-finn-integration.txt | 2 +- requirements/requirements-ort-integration.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements/requirements-export.txt b/requirements/requirements-export.txt index 2432307a3..62d052509 100644 --- a/requirements/requirements-export.txt +++ b/requirements/requirements-export.txt @@ -1,2 +1,2 @@ -onnx +onnx==1.15 onnxoptimizer diff --git a/requirements/requirements-finn-integration.txt b/requirements/requirements-finn-integration.txt index 33a515466..68e8e5a5b 100644 --- a/requirements/requirements-finn-integration.txt +++ b/requirements/requirements-finn-integration.txt @@ -1,5 +1,5 @@ bitstring -onnx +onnx==1.15 onnxoptimizer onnxruntime>=1.15.0 qonnx diff --git a/requirements/requirements-ort-integration.txt b/requirements/requirements-ort-integration.txt index afc10d07b..7b83afc3f 100644 --- a/requirements/requirements-ort-integration.txt +++ b/requirements/requirements-ort-integration.txt @@ -1,4 +1,4 @@ -onnx +onnx==1.15 onnxoptimizer onnxruntime>=1.15.0 qonnx