diff --git a/requirements/requirements-export.txt b/requirements/requirements-export.txt
index 2432307a3..62d052509 100644
--- a/requirements/requirements-export.txt
+++ b/requirements/requirements-export.txt
@@ -1,2 +1,2 @@
-onnx
+onnx==1.15
 onnxoptimizer
diff --git a/requirements/requirements-finn-integration.txt b/requirements/requirements-finn-integration.txt
index 33a515466..68e8e5a5b 100644
--- a/requirements/requirements-finn-integration.txt
+++ b/requirements/requirements-finn-integration.txt
@@ -1,5 +1,5 @@
 bitstring
-onnx
+onnx==1.15
 onnxoptimizer
 onnxruntime>=1.15.0
 qonnx
diff --git a/requirements/requirements-ort-integration.txt b/requirements/requirements-ort-integration.txt
index afc10d07b..7b83afc3f 100644
--- a/requirements/requirements-ort-integration.txt
+++ b/requirements/requirements-ort-integration.txt
@@ -1,4 +1,4 @@
-onnx
+onnx==1.15
 onnxoptimizer
 onnxruntime>=1.15.0
 qonnx
diff --git a/src/brevitas/nn/quant_avg_pool.py b/src/brevitas/nn/quant_avg_pool.py
index 5d567d0ca..554504908 100644
--- a/src/brevitas/nn/quant_avg_pool.py
+++ b/src/brevitas/nn/quant_avg_pool.py
@@ -93,14 +93,10 @@ def __init__(
             output_size: Union[int, Tuple[int, int]],
             trunc_quant: Optional[AccQuantType] = RoundTo8bit,
             return_quant_tensor: bool = True,
-            cache_kernel_size_stride: bool = True,
             **kwargs):
         AdaptiveAvgPool2d.__init__(self, output_size=output_size)
         QuantLayerMixin.__init__(self, return_quant_tensor)
         TruncMixin.__init__(self, trunc_quant=trunc_quant, **kwargs)
-        self.cache_kernel_size_stride = cache_kernel_size_stride
-        self._cached_kernel_size = None
-        self._cached_kernel_stride = None
 
     @property
     def channelwise_separable(self) -> bool:
@@ -114,15 +110,8 @@ def requires_export_handler(self):
     def padding(self):
         return 0
 
-    @property
-    def kernel_size(self):
-        return self._cached_kernel_size
-
-    @property
-    def stride(self):
-        return self._cached_kernel_stride
-
-    def compute_kernel_size_stride(self, input_shape, output_shape):
+    @staticmethod
+    def compute_kernel_size_stride(input_shape, output_shape):
         kernel_size_list = []
         stride_list = []
         for inp, out in zip(input_shape, output_shape):
@@ -141,10 +130,6 @@ def forward(self, input: Union[Tensor, QuantTensor]):
             self._set_global_is_quant_layer(False)
             return out
 
-        if self.cache_kernel_size_stride:
-            self._cached_kernel_size = k_size
-            self._cached_kernel_stride = stride
-
         if isinstance(x, QuantTensor):
             y = x.set(value=super(TruncAdaptiveAvgPool2d, self).forward(x.value))
             k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])
diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py
index c66690c50..569ed71e0 100644
--- a/src/brevitas/quant_tensor/__init__.py
+++ b/src/brevitas/quant_tensor/__init__.py
@@ -22,11 +22,11 @@ class QuantTensorBase(NamedTuple):
     value: Tensor
-    scale: Optional[Tensor]
-    zero_point: Optional[Tensor]
-    bit_width: Optional[Tensor]
-    signed_t: Optional[Tensor]
-    training_t: Optional[Tensor]
+    scale: Tensor
+    zero_point: Tensor
+    bit_width: Tensor
+    signed_t: Tensor
+    training_t: Tensor
 
 
 def _unpack_quant_tensor(input_data):
@@ -61,17 +61,11 @@ def __new__(cls, value, scale, zero_point, bit_width, signed, training):
     @property
     def signed(self):
-        if self.signed_t is not None:
-            return self.signed_t.item()
-        else:
-            return None
+        return self.signed_t.item()
 
     @property
     def training(self):
-        if self.training_t is not None:
-            return self.training_t.item()
-        else:
-            return None
+        return self.training_t.item()
 
     def __torch_function__(self, func, types, args=(), kwargs=None):
         if kwargs is None:
@@ -129,8 +123,7 @@ def device(self):
         value_device = self.value.device
         is_same_device = True
         for t in [self.scale, self.zero_point, self.bit_width]:
-            if t is not None:
-                is_same_device &= value_device == t.device
+            is_same_device &= value_device == t.device
         if not is_same_device:
             raise RuntimeError("Value and metadata are on different devices")
         return value_device
@@ -193,13 +186,13 @@ def is_zero_zero_point(tensor):
         return (tensor.zero_point == 0.).all()
 
     def check_scaling_factors_same(self, other):
-        if self.training is not None and self.training:
+        if self.training:
             return True
         if not torch.allclose(self.scale, other.scale):
             raise RuntimeError("Scaling factors are different")
 
     def check_zero_points_same(self, other):
-        if self.training is not None and self.training:
+        if self.training:
             return True
         if not torch.allclose(self.zero_point, other.zero_point):
             raise RuntimeError("Zero points are different")
@@ -226,7 +219,7 @@ def transpose(self, *args, **kwargs):
         tensor_meta = {
             'scale': self.scale, 'zero_point': self.zero_point, 'bit_width': self.bit_width}
         for k, tm in tensor_meta.items():
-            if tm is not None and len(value.shape) == len(tm.shape):
+            if len(value.shape) == len(tm.shape):
                 tensor_meta[k] = tm.transpose(*args, **kwargs)
         return self.set(value=value, **tensor_meta)
@@ -235,7 +228,7 @@ def permute(self, *args, **kwargs):
         tensor_meta = {
             'scale': self.scale, 'zero_point': self.zero_point, 'bit_width': self.bit_width}
         for k, tm in tensor_meta.items():
-            if tm is not None and len(value.shape) == len(tm.shape):
+            if len(value.shape) == len(tm.shape):
                 tensor_meta[k] = tm.permute(*args, **kwargs)
         return self.set(value=value, **tensor_meta)
@@ -359,8 +352,6 @@ def __add__(self, other):
                 bit_width=output_bit_width,
                 signed=output_signed,
                 training=output_training)
-        elif isinstance(other, QuantTensor):
-            output = self.value + other.value
         else:
             output = self.value + other
         return output
@@ -389,8 +380,6 @@ def __mul__(self, other):
                 bit_width=output_bit_width,
                 signed=output_signed,
                 training=output_training)
-        elif isinstance(other, QuantTensor):
-            output = self.value * other.value
         else:
             output = self.value * other
         return output
@@ -420,8 +409,6 @@ def __truediv__(self, other):
                 bit_width=output_bit_width,
                 signed=output_signed,
                 training=output_training)
-        elif isinstance(other, QuantTensor):
-            output = self.value / other.value
         else:
             output = self.value / other
         return output
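
Note (not part of the diff): the quant_avg_pool.py hunks above turn compute_kernel_size_stride into a @staticmethod that takes only (input_shape, output_shape), and the cached kernel_size/stride attributes are dropped in favor of recomputing them per forward pass. The standalone sketch below illustrates, under the assumption of the usual adaptive-pooling arithmetic, how an equivalent per-dimension kernel size and stride can be derived from input and output spatial shapes; the function shown here is an illustrative reimplementation, not copied from the Brevitas source.

# Illustrative sketch only -- assumes the standard adaptive-pooling arithmetic;
# not taken verbatim from the Brevitas implementation.
from typing import List, Sequence, Tuple


def compute_kernel_size_stride(input_shape: Sequence[int],
                               output_shape: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Derive per-dimension kernel size and stride for an adaptive average pool."""
    kernel_size_list = []
    stride_list = []
    for inp, out in zip(input_shape, output_shape):
        stride = inp // out  # distance between the starts of consecutive windows
        kernel_size = inp - (out - 1) * stride  # window size that covers the remainder
        kernel_size_list.append(kernel_size)
        stride_list.append(stride)
    return kernel_size_list, stride_list


# A 16x16 feature map pooled down to 4x4 gives kernel 4 and stride 4 per dimension.
print(compute_kernel_size_stride((16, 16), (4, 4)))  # ([4, 4], [4, 4])

Because the method no longer depends on instance state, it can be called without constructing the layer, which is consistent with the forward pass in the diff computing k_size and stride on the fly from the input and output shapes.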