
Commit

Merge branch 'dev' of github.com:Xilinx/brevitas into feat/remove_qop_export
costigt-dev committed Mar 27, 2024
2 parents 91e6248 + 44cb08b commit 1dff353
Showing 5 changed files with 17 additions and 45 deletions.
2 changes: 1 addition & 1 deletion requirements/requirements-export.txt
@@ -1,2 +1,2 @@
-onnx
+onnx==1.15
onnxoptimizer
2 changes: 1 addition & 1 deletion requirements/requirements-finn-integration.txt
@@ -1,5 +1,5 @@
bitstring
-onnx
+onnx==1.15
onnxoptimizer
onnxruntime>=1.15.0
qonnx
2 changes: 1 addition & 1 deletion requirements/requirements-ort-integration.txt
@@ -1,4 +1,4 @@
-onnx
+onnx==1.15
onnxoptimizer
onnxruntime>=1.15.0
qonnx
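
For context, the three requirements files above now pin onnx to exactly 1.15 for export and integration flows. A minimal sanity check (a sketch, not part of this change, assuming onnx is installed per these requirements):

# Sketch only: confirm the installed onnx matches the 1.15 pin above.
import onnx

assert onnx.__version__.startswith("1.15"), f"expected onnx 1.15, got {onnx.__version__}"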
19 changes: 2 additions & 17 deletions src/brevitas/nn/quant_avg_pool.py
@@ -93,14 +93,10 @@ def __init__(
output_size: Union[int, Tuple[int, int]],
trunc_quant: Optional[AccQuantType] = RoundTo8bit,
return_quant_tensor: bool = True,
-cache_kernel_size_stride: bool = True,
**kwargs):
AdaptiveAvgPool2d.__init__(self, output_size=output_size)
QuantLayerMixin.__init__(self, return_quant_tensor)
TruncMixin.__init__(self, trunc_quant=trunc_quant, **kwargs)
-self.cache_kernel_size_stride = cache_kernel_size_stride
-self._cached_kernel_size = None
-self._cached_kernel_stride = None

@property
def channelwise_separable(self) -> bool:
@@ -114,15 +110,8 @@ def requires_export_handler(self):
def padding(self):
return 0

-@property
-def kernel_size(self):
-return self._cached_kernel_size
-
-@property
-def stride(self):
-return self._cached_kernel_stride
-
-def compute_kernel_size_stride(self, input_shape, output_shape):
+@staticmethod
+def compute_kernel_size_stride(input_shape, output_shape):
kernel_size_list = []
stride_list = []
for inp, out in zip(input_shape, output_shape):
@@ -141,10 +130,6 @@ def forward(self, input: Union[Tensor, QuantTensor]):
self._set_global_is_quant_layer(False)
return out

-if self.cache_kernel_size_stride:
-self._cached_kernel_size = k_size
-self._cached_kernel_stride = stride

if isinstance(x, QuantTensor):
y = x.set(value=super(TruncAdaptiveAvgPool2d, self).forward(x.value))
k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])
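
For context, compute_kernel_size_stride is now a staticmethod and the cached kernel_size/stride properties are gone, so callers derive the equivalent fixed kernel and stride from shapes alone. A minimal sketch of what such a function can do (the loop body is collapsed in the diff above, so the exact formula here is an assumption: the standard adaptive-average-pool decomposition):

# Sketch only, not the Brevitas source: derive a fixed kernel size and stride
# equivalent to an adaptive average pool over the given spatial shapes.
def compute_kernel_size_stride(input_shape, output_shape):
    kernel_size_list, stride_list = [], []
    for inp, out in zip(input_shape, output_shape):
        stride = inp // out                     # step between pooling windows
        kernel_size = inp - (out - 1) * stride  # window size covering the remainder
        kernel_size_list.append(kernel_size)
        stride_list.append(stride)
    return kernel_size_list, stride_list

# Example: a 7x7 feature map pooled adaptively to 3x3.
print(compute_kernel_size_stride((7, 7), (3, 3)))  # ([3, 3], [2, 2])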
37 changes: 12 additions & 25 deletions src/brevitas/quant_tensor/__init__.py
Expand Up @@ -22,11 +22,11 @@

class QuantTensorBase(NamedTuple):
value: Tensor
-scale: Optional[Tensor]
-zero_point: Optional[Tensor]
-bit_width: Optional[Tensor]
-signed_t: Optional[Tensor]
-training_t: Optional[Tensor]
+scale: Tensor
+zero_point: Tensor
+bit_width: Tensor
+signed_t: Tensor
+training_t: Tensor


def _unpack_quant_tensor(input_data):
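
With the metadata fields no longer Optional, a QuantTensor is expected to carry real tensors for scale, zero_point, bit_width, signed and training. A minimal construction sketch (not part of the diff), assuming the public import path and the __new__(cls, value, scale, zero_point, bit_width, signed, training) signature visible further down in this file:

# Sketch only: build a QuantTensor with every metadata field populated.
import torch
from brevitas.quant_tensor import QuantTensor  # import path assumed

qt = QuantTensor(
    value=torch.randn(1, 3, 4, 4),
    scale=torch.tensor(0.1),
    zero_point=torch.tensor(0.),
    bit_width=torch.tensor(8.),
    signed=torch.tensor(True),
    training=torch.tensor(False))

# signed/training now unwrap their backing tensors unconditionally,
# matching the simplified properties below.
print(qt.signed, qt.training)  # True False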
@@ -61,17 +61,11 @@ def __new__(cls, value, scale, zero_point, bit_width, signed, training):

@property
def signed(self):
-if self.signed_t is not None:
-return self.signed_t.item()
-else:
-return None
+return self.signed_t.item()

@property
def training(self):
-if self.training_t is not None:
-return self.training_t.item()
-else:
-return None
+return self.training_t.item()

def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
@@ -129,8 +123,7 @@ def device(self):
value_device = self.value.device
is_same_device = True
for t in [self.scale, self.zero_point, self.bit_width]:
-if t is not None:
-is_same_device &= value_device == t.device
+is_same_device &= value_device == t.device
if not is_same_device:
raise RuntimeError("Value and metadata are on different devices")
return value_device
@@ -193,13 +186,13 @@ def is_zero_zero_point(tensor):
return (tensor.zero_point == 0.).all()

def check_scaling_factors_same(self, other):
-if self.training is not None and self.training:
+if self.training:
return True
if not torch.allclose(self.scale, other.scale):
raise RuntimeError("Scaling factors are different")

def check_zero_points_same(self, other):
-if self.training is not None and self.training:
+if self.training:
return True
if not torch.allclose(self.zero_point, other.zero_point):
raise RuntimeError("Zero points are different")
@@ -226,7 +219,7 @@ def transpose(self, *args, **kwargs):
tensor_meta = {
'scale': self.scale, 'zero_point': self.zero_point, 'bit_width': self.bit_width}
for k, tm in tensor_meta.items():
-if tm is not None and len(value.shape) == len(tm.shape):
+if len(value.shape) == len(tm.shape):
tensor_meta[k] = tm.transpose(*args, **kwargs)
return self.set(value=value, **tensor_meta)

@@ -235,7 +228,7 @@ def permute(self, *args, **kwargs):
tensor_meta = {
'scale': self.scale, 'zero_point': self.zero_point, 'bit_width': self.bit_width}
for k, tm in tensor_meta.items():
-if tm is not None and len(value.shape) == len(tm.shape):
+if len(value.shape) == len(tm.shape):
tensor_meta[k] = tm.permute(*args, **kwargs)
return self.set(value=value, **tensor_meta)

@@ -359,8 +352,6 @@ def __add__(self, other):
bit_width=output_bit_width,
signed=output_signed,
training=output_training)
-elif isinstance(other, QuantTensor):
-output = self.value + other.value
else:
output = self.value + other
return output
@@ -389,8 +380,6 @@ def __mul__(self, other):
bit_width=output_bit_width,
signed=output_signed,
training=output_training)
-elif isinstance(other, QuantTensor):
-output = self.value * other.value
else:
output = self.value * other
return output
@@ -420,8 +409,6 @@ def __truediv__(self, other):
bit_width=output_bit_width,
signed=output_signed,
training=output_training)
-elif isinstance(other, QuantTensor):
-output = self.value / other.value
else:
output = self.value / other
return output
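
With the elif branches above removed, arithmetic against a non-QuantTensor operand falls through to the final else branch and returns a plain value rather than a new QuantTensor. A small usage sketch under that assumption (not part of the diff; import path assumed):

# Sketch only: QuantTensor combined with a plain number takes the else branch.
import torch
from brevitas.quant_tensor import QuantTensor  # import path assumed

qt = QuantTensor(
    torch.zeros(2, 2), torch.tensor(0.1), torch.tensor(0.),
    torch.tensor(8.), torch.tensor(True), torch.tensor(False))

out = qt + 1.0      # output = self.value + other
print(type(out))    # <class 'torch.Tensor'>, not QuantTensor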
