diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index c66690c50..bfbe81446 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -22,11 +22,11 @@ class QuantTensorBase(NamedTuple): value: Tensor - scale: Optional[Tensor] - zero_point: Optional[Tensor] - bit_width: Optional[Tensor] - signed_t: Optional[Tensor] - training_t: Optional[Tensor] + scale: Tensor + zero_point: Tensor + bit_width: Tensor + signed_t: Tensor + training_t: Tensor def _unpack_quant_tensor(input_data): @@ -61,17 +61,11 @@ def __new__(cls, value, scale, zero_point, bit_width, signed, training): @property def signed(self): - if self.signed_t is not None: - return self.signed_t.item() - else: - return None + return self.signed_t.item() @property def training(self): - if self.training_t is not None: - return self.training_t.item() - else: - return None + return self.training_t.item() def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: @@ -129,8 +123,7 @@ def device(self): value_device = self.value.device is_same_device = True for t in [self.scale, self.zero_point, self.bit_width]: - if t is not None: - is_same_device &= value_device == t.device + is_same_device &= value_device == t.device if not is_same_device: raise RuntimeError("Value and metadata are on different devices") return value_device @@ -193,13 +186,13 @@ def is_zero_zero_point(tensor): return (tensor.zero_point == 0.).all() def check_scaling_factors_same(self, other): - if self.training is not None and self.training: + if self.training: return True if not torch.allclose(self.scale, other.scale): raise RuntimeError("Scaling factors are different") def check_zero_points_same(self, other): - if self.training is not None and self.training: + if self.training: return True if not torch.allclose(self.zero_point, other.zero_point): raise RuntimeError("Zero points are different") @@ -226,7 +219,7 @@ def transpose(self, *args, **kwargs): tensor_meta = { 'scale': self.scale, 'zero_point': self.zero_point, 'bit_width': self.bit_width} for k, tm in tensor_meta.items(): - if tm is not None and len(value.shape) == len(tm.shape): + if len(value.shape) == len(tm.shape): tensor_meta[k] = tm.transpose(*args, **kwargs) return self.set(value=value, **tensor_meta) @@ -235,7 +228,7 @@ def permute(self, *args, **kwargs): tensor_meta = { 'scale': self.scale, 'zero_point': self.zero_point, 'bit_width': self.bit_width} for k, tm in tensor_meta.items(): - if tm is not None and len(value.shape) == len(tm.shape): + if len(value.shape) == len(tm.shape): tensor_meta[k] = tm.permute(*args, **kwargs) return self.set(value=value, **tensor_meta)