From 2d3875fa07b95e3aa1619e6945c06b89ffa1981e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 1 Feb 2024 18:56:37 +0000 Subject: [PATCH 01/32] No Empty QT --- src/brevitas/core/quant/binary.py | 3 +- src/brevitas/core/stats/stats_op.py | 2 +- .../onnx/standard/qoperator/handler/base.py | 2 +- src/brevitas/graph/calibrate.py | 2 +- src/brevitas/nn/mixin/base.py | 37 +++++---- src/brevitas/nn/mixin/parameter.py | 6 +- src/brevitas/nn/quant_layer.py | 78 +++++++++++++------ src/brevitas/nn/quant_rnn.py | 70 +++++++++-------- src/brevitas/proxy/parameter_quant.py | 4 +- src/brevitas/proxy/runtime_quant.py | 10 ++- src/brevitas/quant_tensor/__init__.py | 37 ++++++++- tests/brevitas/nn/test_nn_quantizers.py | 21 ++++- tests/brevitas_ort/common.py | 11 +-- 13 files changed, 191 insertions(+), 92 deletions(-) diff --git a/src/brevitas/core/quant/binary.py b/src/brevitas/core/quant/binary.py index 392cdeb62..3a4b7346e 100644 --- a/src/brevitas/core/quant/binary.py +++ b/src/brevitas/core/quant/binary.py @@ -48,8 +48,9 @@ class BinaryQuant(brevitas.jit.ScriptModule): Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module. """ - def __init__(self, scaling_impl: Module, quant_delay_steps: int = 0): + def __init__(self, scaling_impl: Module, signed: bool = True, quant_delay_steps: int = 0): super(BinaryQuant, self).__init__() + assert signed, "Unsigned binary quant not supported" self.scaling_impl = scaling_impl self.bit_width = BitWidthConst(1) self.zero_point = StatelessBuffer(torch.tensor(0.0)) diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py index bfcfbb58f..dee6011d5 100644 --- a/src/brevitas/core/stats/stats_op.py +++ b/src/brevitas/core/stats/stats_op.py @@ -479,7 +479,7 @@ def evaluate_loss(self, x, candidate): self.set_local_loss_mode(True) quant_value = self.proxy_forward(x) if isinstance(quant_value, tuple): - quant_value = quant_value[0] + quant_value = quant_value.value loss = self.mse_loss_fn(x, quant_value) self.set_local_loss_mode(False) return loss diff --git a/src/brevitas/export/onnx/standard/qoperator/handler/base.py b/src/brevitas/export/onnx/standard/qoperator/handler/base.py index 2dfcf6037..e614d2ed5 100644 --- a/src/brevitas/export/onnx/standard/qoperator/handler/base.py +++ b/src/brevitas/export/onnx/standard/qoperator/handler/base.py @@ -104,7 +104,7 @@ def input_quant_symbolic_kwargs(cls, module): @classmethod def input_dequant_symbolic_kwargs(cls, module): - if module._cached_inp.scale is not None: + if module._cached_inp is not None: return cls.dequant_symbolic_kwargs_from_cached_io(module._cached_inp) else: return None diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py index bb435b7ef..79cedff7f 100644 --- a/src/brevitas/graph/calibrate.py +++ b/src/brevitas/graph/calibrate.py @@ -168,7 +168,7 @@ def disable_act_quant_hook(self, module, inp, output): if isinstance(module.tracked_module_list[0], QuantHardTanh): inp = F.hardtanh( inp, min_val=module.quant_injector.min_val, max_val=module.quant_injector.max_val) - return QuantTensor(value=inp, training=module.training) + return inp def disable_act_quantization(self, model, is_training): # If self.call_act_quantizer_impl is set to True, the quantization will be performed but the output diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index 2d4fa97ad..12a252398 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -18,6 +18,7 @@ from brevitas.inject import ExtendedInjector 
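The changes above all follow one convention that the rest of the patch builds on: a proxy or hook whose quantization is disabled now hands back the plain torch.Tensor instead of wrapping it in a metadata-less QuantTensor. A minimal standalone sketch of the user-visible effect, assuming the stock QuantIdentity layer (illustration only, not part of this diff):

import torch
from brevitas.nn import QuantIdentity
from brevitas.quant_tensor import QuantTensor

quant_act = QuantIdentity(return_quant_tensor=True)
no_quant_act = QuantIdentity(act_quant=None)  # quantization disabled

x = torch.randn(4, 8)
assert isinstance(quant_act(x), QuantTensor)      # quantized path: metadata attached
assert isinstance(no_quant_act(x), torch.Tensor)  # disabled path: plain tensor, no empty QuantTensor
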
from brevitas.inject import Injector from brevitas.nn.utils import compute_channel_view_shape +from brevitas.quant_tensor import _is_all_nested_not_none from brevitas.quant_tensor import QuantTensor from .utils import filter_kwargs @@ -166,15 +167,17 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]): if not self.training and not self._export_mode and self.cache_inference_quant_inp: cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only) self._cached_inp = cached_inp - else: - inp = QuantTensor(inp, training=self.training) - if not self.training and self.cache_inference_quant_inp: - cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only) - self._cached_inp = cached_inp + # else: + # if not self.training and self.cache_inference_quant_inp: + # cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only) + # self._cached_inp = cached_inp # Remove any naming metadata to avoid dowmstream errors # Avoid inplace operations on the input in case of forward hooks if not torch._C._get_tracing_state(): - inp = inp.set(value=inp.value.rename(None)) + if isinstance(inp, QuantTensor): + inp = inp.set(qt_value=inp.qt_value.rename(None)) + else: + inp = inp.rename(None) return inp def pack_output(self, quant_output: QuantTensor): @@ -184,7 +187,10 @@ def pack_output(self, quant_output: QuantTensor): if self.return_quant_tensor: return quant_output else: - return quant_output.value + if isinstance(quant_output, QuantTensor): + return quant_output.value + else: + return quant_output class QuantRecurrentLayerMixin(ExportMixin): @@ -246,9 +252,10 @@ def gate_params_fwd(gate, quant_input): acc_bit_width = None quant_weight_ih = gate.input_weight() quant_weight_hh = gate.hidden_weight() - if quant_input.bit_width is not None: + if isinstance(quant_input, QuantTensor): acc_bit_width = None # TODO - if quant_input.scale is not None and quant_weight_ih.scale is not None: + if getattr(quant_input, 'scale', None) is not None and getattr( + quant_weight_ih, 'scale', None) is not None: acc_scale_shape = compute_channel_view_shape(quant_input.value, channel_dim=1) acc_scale = quant_weight_ih.scale.view(acc_scale_shape) acc_scale = acc_scale * quant_input.scale.view(acc_scale_shape) @@ -267,8 +274,8 @@ def maybe_quantize_input(self, inp): quant_input = inp if not self.quantize_output_only: quant_input = self.io_quant(quant_input) - elif not isinstance(inp, QuantTensor): - quant_input = QuantTensor(quant_input) + # elif not isinstance(inp, QuantTensor): + # quant_input = QuantTensor(quant_input) return quant_input def maybe_quantize_state(self, inp, state, quant): @@ -276,7 +283,7 @@ def maybe_quantize_state(self, inp, state, quant): batch_size = inp.size(0) if self.cell.batch_first else inp.size(1) quant_state = torch.zeros( int(batch_size), self.hidden_size, dtype=inp.dtype, device=inp.device) - quant_state = QuantTensor(quant_state) + # quant_state = QuantTensor(quant_state) else: quant_state = quant(state) return quant_state @@ -303,7 +310,8 @@ def pack_quant_outputs(self, quant_outputs): quant_output[2], quant_output[3], self.io_quant.is_signed, - self.training) for quant_output in quant_outputs] + self.training, + _allow_empty=True) for quant_output in quant_outputs] else: outputs = [torch.unsqueeze(o[0], dim=seq_dim) for o in quant_outputs] if self.reverse_input: @@ -331,7 +339,8 @@ def pack_quant_state(self, quant_state, quant): quant_state[2], quant_state[3], quant.is_signed, - self.training) + training=self.training, + _allow_empty=True) else: quant_state = 
torch.unsqueeze(quant_state[0], dim=0) return quant_state diff --git a/src/brevitas/nn/mixin/parameter.py b/src/brevitas/nn/mixin/parameter.py index 095c981f1..3acfe7c95 100644 --- a/src/brevitas/nn/mixin/parameter.py +++ b/src/brevitas/nn/mixin/parameter.py @@ -198,7 +198,11 @@ def quant_bias_zero_point(self): if self.bias is None: return None if not self.bias_quant.requires_input_scale and not self.bias_quant.requires_input_bit_width: - return self.bias_quant(self.bias).zero_point + bias_quant = self.bias_quant(self.bias) + if isinstance(bias_quant, QuantTensor): + return bias_quant.zero_point + else: + return None else: if self._cached_bias is None: raise RuntimeError( diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index 7208aa8e3..f20b93f65 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -19,6 +19,10 @@ from .utils import rename_state_dict_by_prefix +def return_value(tensor): + return tensor.value if isinstance(tensor, QuantTensor) else tensor + + class QuantNonLinearActLayer(QuantNonLinearActMixin, QuantInputMixin, QuantLayerMixin, Module): __metaclass__ = ABCMeta @@ -303,61 +307,89 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe # shortcut execution through the export impl during export if self.export_mode: - out = self.export_handler(inp.value) + inp_value = getattr(inp, 'value', inp) + out = self.export_handler(inp_value) self._set_global_is_quant_layer(False) return out quant_input = self.input_quant(inp) + # quant_input_value = getattr(quant_input, 'value', quant_input) + # quant_input_scale = getattr(quant_input, 'scale', None) + # quant_input_bitwidth = getattr(quant_input, 'bit_width', None) + quant_weight = self.quant_weight(quant_input) + # quant_weight_value = getattr(quant_weight, 'value', quant_weight) + # quant_weight_scale = getattr(quant_weight, 'scale', None) + # quant_weight_bitwidth = getattr(quant_weight, 'bit_width', None) + compute_output_quant_tensor = isinstance(quant_input, QuantTensor) and isinstance( + quant_weight, QuantTensor) + if not (compute_output_quant_tensor or + self.is_output_quant_enabled) and self.return_quant_tensor: + raise RuntimeError("QuantLayer is not correctly configured") if (self.return_quant_tensor or (self.is_bias_quant_enabled and (self.bias_quant.requires_input_scale or self.bias_quant.requires_input_bit_width))): - if quant_input.bit_width is not None and quant_weight.bit_width is not None: + if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor): output_bit_width = self.max_acc_bit_width( quant_input.bit_width, quant_weight.bit_width) - if quant_input.scale is not None and quant_weight.scale is not None: + output_scale = self.quant_output_scale_impl( inp, quant_input.scale, quant_weight.scale) - if quant_input.signed is not None: - output_signed = inp.signed or quant_weight.signed + + quant_input_signed = quant_input.signed if isinstance( + quant_input, QuantTensor) else True + quant_weight_signed = quant_weight.signed if isinstance( + quant_weight, QuantTensor) else True + output_signed = quant_input_signed or quant_weight_signed if self.bias is not None: quant_bias = self.bias_quant(self.bias, output_scale, output_bit_width) + quant_bias_value = getattr(quant_bias, 'value', quant_bias) + quant_bias_scale = getattr(quant_bias, 'scale', None) + quant_bias_bitwidth = getattr(quant_bias, 'bit_width', None) if not self.training and self.cache_inference_quant_bias: self._cached_bias = 
_CachedIO(quant_bias.detach(), metadata_only=False) - output_tensor = self.inner_forward_impl( - quant_input.value, quant_weight.value, quant_bias.value) + return_value(quant_input), return_value(quant_weight), return_value(quant_bias)) if (self.return_quant_tensor and output_scale is not None and - (quant_bias.scale is None or - (quant_bias.scale is not None and - quant_bias.scale.data_ptr() != output_scale.data_ptr()))): - output_scale_broadcast_shape = compute_channel_view_shape(inp, channel_dim=1) - output_zero_point = -quant_bias.value.view( + (quant_bias_scale is None or + (quant_bias_scale is not None and + quant_bias_scale.data_ptr() != output_scale.data_ptr()))): + channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1 + output_scale_broadcast_shape = compute_channel_view_shape( + inp, channel_dim=channel_dim) + output_zero_point = -quant_bias_value.view( output_scale_broadcast_shape) / output_scale - if quant_bias.bit_width is not None and output_bit_width is not None: + if hasattr(quant_bias, 'bit_width' + ) and quant_bias_bitwidth is not None and output_bit_width is not None: output_bit_width = torch.where( - quant_bias.bit_width > output_bit_width, quant_bias.bit_width, output_bit_width) + quant_bias_bitwidth > output_bit_width, quant_bias_bitwidth, output_bit_width) output_bit_width = output_bit_width + 1 else: - output_tensor = self.inner_forward_impl(quant_input.value, quant_weight.value, None) + output_tensor = self.inner_forward_impl( + return_value(quant_input), return_value(quant_weight), None) if self.return_quant_tensor and not self.is_output_quant_enabled: - if (quant_input.zero_point is not None and quant_weight.zero_point is not None and + if (isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor) and ((quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any())): raise RuntimeError("Computing zero point of output accumulator not supported yet.") elif quant_input.zero_point is not None and output_zero_point is None: output_zero_point = quant_input.zero_point + elif self.return_quant_tensor and output_zero_point is None: + output_zero_point = torch.zeros(1).type_as(output_tensor) - quant_output = QuantTensor( - value=output_tensor, - scale=output_scale, - zero_point=output_zero_point, - bit_width=output_bit_width, - signed=output_signed, - training=self.training) + if not self.return_quant_tensor or not compute_output_quant_tensor: + quant_output = output_tensor + else: + quant_output = QuantTensor.from_fake_quantized( + output_tensor, + scale=output_scale, + zero_point=output_zero_point, + bit_width=output_bit_width, + signed=output_signed, + training=self.training) quant_output = self.output_quant(quant_output) return self.pack_output(quant_output) diff --git a/src/brevitas/nn/quant_rnn.py b/src/brevitas/nn/quant_rnn.py index 642c1b2d1..6e7ec581a 100644 --- a/src/brevitas/nn/quant_rnn.py +++ b/src/brevitas/nn/quant_rnn.py @@ -23,6 +23,8 @@ from brevitas.quant import Int8WeightPerTensorFloat from brevitas.quant import Int32Bias from brevitas.quant import Uint8ActPerTensorFloat +from brevitas.quant_tensor import _unpack_quant_tensor +from brevitas.quant_tensor import QuantTensor QuantTupleShortEnabled = List[Tuple[Tensor, Tensor, Tensor, Tensor]] QuantTupleShortDisabled = List[Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]] @@ -416,11 +418,12 @@ def forward(self, inp, state): quant_input = self.maybe_quantize_input(inp) quant_weight_ih, quant_weight_hh, quant_bias = self.gate_params_fwd( 
self.gate_params, quant_input) - if quant_bias.value is None: - quant_bias = torch.tensor(0., device=quant_input.value.device) + quant_input_value = _unpack_quant_tensor(quant_input) + if getattr(quant_bias, 'value', quant_bias) is None: + quant_bias = torch.tensor(0., device=quant_input_value.device) else: - quant_bias = quant_bias.value - quant_state = self.maybe_quantize_state(quant_input.value, state, self.cell.output_quant) + quant_bias = _unpack_quant_tensor(quant_bias) + quant_state = self.maybe_quantize_state(quant_input_value, state, self.cell.output_quant) if self.export_mode: cell = self.export_handler elif self.fast_mode: @@ -428,10 +431,10 @@ def forward(self, inp, state): else: cell = self.cell quant_outputs = cell( - quant_input.value, - quant_state.value, - quant_weight_ih.value, - quant_weight_hh.value, + _unpack_quant_tensor(quant_input), + _unpack_quant_tensor(quant_state), + _unpack_quant_tensor(quant_weight_ih), + _unpack_quant_tensor(quant_weight_hh), quant_bias) quant_output = self.pack_quant_outputs(quant_outputs) quant_state = self.pack_quant_state(quant_outputs[-1], self.cell.output_quant) @@ -666,6 +669,7 @@ def fast_cell(self): def forward(self, inp, hidden_state, cell_state): quant_input = self.maybe_quantize_input(inp) + quant_input_value = _unpack_quant_tensor(quant_input) quant_weight_ii, quant_weight_hi, quant_bias_input = self.gate_params_fwd( self.input_gate_params, quant_input) quant_weight_ic, quant_weight_hc, quant_bias_cell = self.gate_params_fwd( @@ -680,26 +684,26 @@ def forward(self, inp, hidden_state, cell_state): quant_weight_if, quant_weight_hf, quant_bias_forget = self.gate_params_fwd( self.forget_gate_params, quant_input) # Handle None bias by setting it 0. - if quant_bias_input.value is None: - quant_bias_input = torch.tensor(0., device=quant_input.value.device) + if getattr(quant_bias_input, 'value', quant_bias_input) is None: + quant_bias_input = torch.tensor(0., device=quant_input_value.device) else: - quant_bias_input = quant_bias_input.value - if quant_bias_forget.value is None: - quant_bias_forget = torch.tensor(0., device=quant_input.value.device) + quant_bias_input = _unpack_quant_tensor(quant_bias_input) + if getattr(quant_bias_forget, 'value', quant_bias_forget) is None: + quant_bias_forget = torch.tensor(0., device=quant_input_value.device) else: - quant_bias_forget = quant_bias_forget.value - if quant_bias_cell.value is None: - quant_bias_cell = torch.tensor(0., device=quant_input.value.device) + quant_bias_forget = _unpack_quant_tensor(quant_bias_forget) + if getattr(quant_bias_cell, 'value', quant_bias_cell) is None: + quant_bias_cell = torch.tensor(0., device=quant_input_value.device) else: - quant_bias_cell = quant_bias_cell.value - if quant_bias_output.value is None: - quant_bias_output = torch.tensor(0., device=quant_input.value.device) + quant_bias_cell = _unpack_quant_tensor(quant_bias_cell) + if getattr(quant_bias_output, 'value', quant_bias_output) is None: + quant_bias_output = torch.tensor(0., device=quant_input_value.device) else: - quant_bias_output = quant_bias_output.value + quant_bias_output = _unpack_quant_tensor(quant_bias_output) quant_hidden_state = self.maybe_quantize_state( - quant_input.value, hidden_state, self.cell.output_quant) + quant_input_value, hidden_state, self.cell.output_quant) quant_cell_state = self.maybe_quantize_state( - quant_input.value, cell_state, self.cell.cell_state_quant) + quant_input_value, cell_state, self.cell.cell_state_quant) # Pick cell impl if self.export_mode: cell = 
self.export_handler @@ -708,17 +712,17 @@ def forward(self, inp, hidden_state, cell_state): else: cell = self.cell quant_outputs, quant_hidden_state, quant_cell_state = cell( - quant_input.value, - quant_hidden_state.value, - quant_cell_state.value, - quant_weight_ii=quant_weight_ii.value, - quant_weight_if=quant_weight_if.value, - quant_weight_ic=quant_weight_ic.value, - quant_weight_io=quant_weight_io.value, - quant_weight_hi=quant_weight_hi.value, - quant_weight_hf=quant_weight_hf.value, - quant_weight_hc=quant_weight_hc.value, - quant_weight_ho=quant_weight_ho.value, + quant_input_value, + _unpack_quant_tensor(quant_hidden_state), + _unpack_quant_tensor(quant_cell_state), + quant_weight_ii=_unpack_quant_tensor(quant_weight_ii), + quant_weight_if=_unpack_quant_tensor(quant_weight_if), + quant_weight_ic=_unpack_quant_tensor(quant_weight_ic), + quant_weight_io=_unpack_quant_tensor(quant_weight_io), + quant_weight_hi=_unpack_quant_tensor(quant_weight_hi), + quant_weight_hf=_unpack_quant_tensor(quant_weight_hf), + quant_weight_hc=_unpack_quant_tensor(quant_weight_hc), + quant_weight_ho=_unpack_quant_tensor(quant_weight_ho), quant_bias_input=quant_bias_input, quant_bias_forget=quant_bias_forget, quant_bias_cell=quant_bias_cell, diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 5a4b2ed55..f7f120697 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -100,7 +100,7 @@ def forward(self, x: torch.Tensor) -> QuantTensor: out, scale, zero_point, bit_width = impl(x) return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) else: # quantization disabled - return QuantTensor(x, training=self.training) + return x class DecoupledWeightQuantProxyFromInjector(WeightQuantProxyFromInjector): @@ -218,4 +218,4 @@ def forward( raise RuntimeError("Internally defined bit-width required") return QuantTensor(out, out_scale, out_zp, out_bit_width, self.is_signed, self.training) else: - return QuantTensor(x, training=self.training) + return x diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 9d15f3bba..e67446e3a 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -157,16 +157,20 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor: elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant if isinstance(y, tuple): y = y[0] - return QuantTensor(y, x.scale, x.zero_point, x.bit_width, x.signed, self.training) + if isinstance(x, QuantTensor): + return QuantTensor.from_fake_quantized( + y, x.scale, x.zero_point, x.bit_width, x.signed, self.training) + else: + return y else: if isinstance(y, tuple): y = y[0] - return QuantTensor(y, training=self.training) + return y else: if isinstance(x, QuantTensor): # passthrough return x else: - return QuantTensor(x, training=self.training) + return x class DynamicActQuantProxyFromInjector(ActQuantProxyFromInjector): diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index bd1da8edd..9c80f6969 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -3,10 +3,12 @@ from abc import ABC from typing import NamedTuple, Optional +import warnings import torch from torch import Tensor +import brevitas.config as config from brevitas.function.ops import max_int from brevitas.function.ops import min_int from brevitas.function.ops_ste import ceil_ste @@ -18,6 +20,10 @@ 
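The quant_tensor hunk that follows fixes _unpack_quant_tensor, which previously read a non-existent .tensor field, and makes it the canonical way to strip quantization metadata from arbitrarily nested inputs, as used by the quant_rnn changes above. A short usage sketch (standalone; _unpack_quant_tensor is a private helper, shown here only to illustrate the recursion):

import torch
from brevitas.quant_tensor import QuantTensor, _unpack_quant_tensor

qt = QuantTensor(
    torch.tensor([0.5, 1.0]), scale=torch.tensor(0.5), zero_point=torch.tensor(0.),
    bit_width=torch.tensor(8.), signed=True, training=False)

# Tuples and lists are traversed recursively; each QuantTensor is replaced by its .value.
unpacked = _unpack_quant_tensor((qt, [qt, torch.ones(2)]))
print(unpacked)  # (tensor([0.5000, 1.0000]), [tensor([0.5000, 1.0000]), tensor([1., 1.])])
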
 BFLOAT16_IS_VALID_ATOL = 0.5
 
 
+def _get_dequantize_tensor(input):
+    return input.value if isinstance(input, QuantTensor) else input
+
+
 class QuantTensorBase(NamedTuple):
     value: Tensor
     scale: Optional[Tensor]
@@ -29,7 +35,7 @@ class QuantTensorBase(NamedTuple):
 
 def _unpack_quant_tensor(input_data):
     if isinstance(input_data, QuantTensor):
-        return input_data.tensor
+        return input_data.value
     elif isinstance(input_data, tuple):
         return tuple([_unpack_quant_tensor(v) for v in input_data])
     elif isinstance(input_data, list):
@@ -54,7 +60,14 @@ def _is_all_nested_not_none(input_data):
 class QuantTensor(QuantTensorBase):
 
     def __new__(
-            cls, value, scale=None, zero_point=None, bit_width=None, signed=None, training=None):
+            cls,
+            value=None,
+            scale=None,
+            zero_point=None,
+            bit_width=None,
+            signed=None,
+            training=None,
+            _allow_empty=False):
 
         if scale is not None and not isinstance(scale, torch.Tensor):
             scale = torch.tensor(scale, dtype=torch.float)
@@ -66,7 +79,22 @@ def __new__(
             signed = torch.tensor(signed, dtype=torch.bool)
         if training is not None and not isinstance(training, torch.Tensor):
             training = torch.tensor(training, dtype=torch.bool)
-        return super().__new__(cls, value, scale, zero_point, bit_width, signed, training)
+
+        if _allow_empty:
+            warnings.warn(
+                "Empty QuantTensors are deprecated and will be removed in a future version")
+        # elif value is not None and scale is not None and zero_point is not None:
+        #     is_int = torch.allclose(torch.round(int_value), int_value)
+        #     if not is_int:
+        #         quant_tensor = quant_tensor.set(int_value = torch.round(int_value / scale + zero_point))
+        # elif int_value is None and value is not None:
+        #     pass
+        elif not _allow_empty and (scale is None or bit_width is None or zero_point is None):
+            raise RuntimeError("To create an empty QuantTensor, set _allow_empty=True")
+
+        quant_tensor = super().__new__(
+            cls, value, scale, zero_point, bit_width, signed, training)
+        return quant_tensor
 
     @property
     def signed(self):
@@ -420,6 +448,9 @@ def __mul__(self, other):
     def __sub__(self, other):
         return self.__add__(-other)
 
+    def __str__(self):
+        return f"QuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})"
+
     def __truediv__(self, other):
         if isinstance(other, QuantTensor) and self.is_not_none and other.is_not_none:
             output_tensor = self.value / other.value  # Note, output tensor not guaranteed to pass self.is_valid()
diff --git a/tests/brevitas/nn/test_nn_quantizers.py b/tests/brevitas/nn/test_nn_quantizers.py
index 55296ff35..7f2bdcd7d 100644
--- a/tests/brevitas/nn/test_nn_quantizers.py
+++ b/tests/brevitas/nn/test_nn_quantizers.py
@@ -41,8 +41,13 @@ def test_quant_wbiol(model_input, current_cases):
 
     is_input_quanttensor = kwargs['io_quant'] is not None or kwargs['input_quantized']
 
-    if (not is_input_quanttensor or
-            kwargs['weight_quant'] is None) and kwargs['bias_quant'] == 'quant_external':
+    if (not (is_input_quanttensor and kwargs['weight_quant'] is not None) and
+            kwargs['io_quant'] is None) and kwargs['return_quant_tensor']:
+        with pytest.raises(RuntimeError, match='QuantLayer is not correctly configured'):
+            output = model(input)
+        return
+    elif (not is_input_quanttensor or
+          kwargs['weight_quant'] is None) and kwargs['bias_quant'] == 'quant_external':
         with pytest.raises(RuntimeError, match='Input scale required'):
             output = model(input)
         return
@@ -199,12 +204,20 @@ def test_quant_mha(model_input, current_cases):
     case_id =
get_case_id(cases_generator_func) args = case_id.split('-')[1:] # Exclude first argument kwargs = parse_args(args) - - if (kwargs['io_quant'] is None or + is_input_quanttensor = kwargs['io_quant'] is not None or kwargs['input_quantized'] + if (not is_input_quanttensor or kwargs['weight_quant'] is None) and kwargs['bias_quant'] == 'quant_external': with pytest.raises(RuntimeError, match='Input scale required'): output, _ = model(inp, inp, inp) return + elif kwargs['io_quant'] is None and kwargs['return_quant_tensor']: + with pytest.raises(RuntimeError, match='QuantLayer is not correctly configured'): + output, _ = model(inp, inp, inp) + return + elif kwargs['io_quant'] is None and kwargs['bias_quant'] == 'quant_external': + with pytest.raises(RuntimeError, match='Input scale required'): + output, _ = model(inp, inp, inp) + return output, _ = model(inp, inp, inp) diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index a7f87cbef..e01596dc9 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -116,6 +116,7 @@ def is_brevitas_ort_close( input_t = torch.from_numpy(np_input) with torch.no_grad(): brevitas_output = model(input_t) + computed_out = brevitas_output.value if tolerance is not None and export_type == 'qcdq': tolerance = tolerance * brevitas_output.scale # Float Output, tolerance is +/- output scale @@ -130,7 +131,7 @@ def is_brevitas_ort_close( else: if export_type == 'qop': export_onnx_qop(model, input_t, export_path=export_name) - brevitas_output = brevitas_output.int(float_datatype=False) + computed_out = brevitas_output.int(float_datatype=False) elif export_type == 'qcdq': export_onnx_qcdq(model, input_t, export_path=export_name) elif export_type == 'qcdq_opset14': @@ -145,13 +146,13 @@ def is_brevitas_ort_close( if first_output_only: if isinstance(ort_output, (tuple, list)): ort_output = ort_output[0] - if isinstance(brevitas_output, tuple): - brevitas_output = brevitas_output[0] + if isinstance(computed_out, tuple): + computed_out = computed_out[0] # make sure we are not comparing 0s - if (ort_output == 0).all() and (brevitas_output == 0).all(): + if (ort_output == 0).all() and (computed_out == 0).all(): pytest.skip("Skip testing against all 0s.") - return recursive_allclose(ort_output, brevitas_output, tolerance) + return recursive_allclose(ort_output, computed_out, tolerance) def gen_linspaced_data(num_samples, min_val=-1.0, max_val=1.0): From a781dcc884d98badf888bb922221b0a8a8e0131f Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 2 Feb 2024 09:37:42 +0000 Subject: [PATCH 02/32] Fixes --- ...1_quant_tensor_quant_conv2d_overview.ipynb | 552 +++++----- notebooks/02_quant_activation_overview.ipynb | 165 +-- notebooks/03_anatomy_of_a_quantizer.ipynb | 455 ++++---- notebooks/Brevitas_TVMCon2021.ipynb | 975 +++++++++++++++--- src/brevitas/nn/mixin/base.py | 4 +- src/brevitas/nn/quant_layer.py | 7 +- src/brevitas/proxy/runtime_quant.py | 2 +- 7 files changed, 1381 insertions(+), 779 deletions(-) diff --git a/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb b/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb index 2e9ef9179..ef7fd28ca 100644 --- a/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb +++ b/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb @@ -18,14 +18,6 @@ "execution_count": 1, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/user/.local/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. 
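The common.py change above separates the QuantTensor returned by the model from the tensor actually compared against ONNX Runtime (computed_out). A simplified sketch of that selection logic, with names shortened and assuming the model returned a QuantTensor:

def extract_compared_tensor(brevitas_output, export_type):
    # QOp export is compared in the integer domain; everything else is
    # compared in the dequantized float domain via .value.
    if export_type == 'qop':
        return brevitas_output.int(float_datatype=False)
    return brevitas_output.value
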
Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - }, { "data": { "text/markdown": [ @@ -39,14 +31,22 @@ " padding: Union[int, Tuple[int, int]] = 0,\n", " dilation: Union[int, Tuple[int, int]] = 1,\n", " groups: int = 1,\n", + " padding_mode: str = 'zeros',\n", " bias: bool = True,\n", - " padding_type: str = 'standard',\n", " weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat,\n", " bias_quant: Optional[BiasQuantType] = None,\n", " input_quant: Optional[ActQuantType] = None,\n", " output_quant: Optional[ActQuantType] = None,\n", " return_quant_tensor: bool = False,\n", + " device: Optional[torch.device] = None,\n", + " dtype: Optional[torch.dtype] = None,\n", " **kwargs) -> None:\n", + " # avoid an init error in the super class by setting padding to 0\n", + " if padding_mode == 'zeros' and padding == 'same' and stride > 1:\n", + " padding = 0\n", + " is_same_padded_strided = True\n", + " else:\n", + " is_same_padded_strided = False\n", " Conv2d.__init__(\n", " self,\n", " in_channels=in_channels,\n", @@ -54,9 +54,12 @@ " kernel_size=kernel_size,\n", " stride=stride,\n", " padding=padding,\n", + " padding_mode=padding_mode,\n", " dilation=dilation,\n", " groups=groups,\n", - " bias=bias)\n", + " bias=bias,\n", + " device=device,\n", + " dtype=dtype)\n", " QuantWBIOL.__init__(\n", " self,\n", " weight_quant=weight_quant,\n", @@ -65,9 +68,7 @@ " output_quant=output_quant,\n", " return_quant_tensor=return_quant_tensor,\n", " **kwargs)\n", - " assert self.padding_mode == 'zeros'\n", - " assert not (padding_type == 'same' and padding != 0)\n", - " self.padding_type = padding_type\n", + " self.is_same_padded_strided = is_same_padded_strided\n", "\n", "```" ], @@ -149,20 +150,28 @@ "scrolled": true }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/_tensor.py:1362: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. 
(Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/c10/core/TensorImpl.h:1900.)\n", + " return super().rename(names)\n" + ] + }, { "data": { "text/plain": [ - "tensor([[[[-0.2594, 0.5392, 0.5916],\n", - " [ 0.3493, 0.6813, 0.2499],\n", - " [ 1.3732, 0.1229, -0.0084]],\n", + "tensor([[[[-0.3189, -0.2848, -0.0037],\n", + " [ 0.2287, 0.7919, -0.2949],\n", + " [ 0.7699, 0.6641, -0.1161]],\n", "\n", - " [[ 0.0031, -0.1702, 0.1069],\n", - " [-0.8181, -0.8056, 0.0385],\n", - " [-0.4738, 0.0589, 0.1278]],\n", + " [[-0.0886, -0.1660, 1.7264],\n", + " [ 0.8113, 0.8065, -0.8843],\n", + " [-0.3388, -0.1821, -0.3209]],\n", "\n", - " [[-0.1718, -0.1162, -0.1526],\n", - " [-0.9903, -0.3541, 0.1645],\n", - " [ 0.0557, -0.4458, -0.2080]]]], grad_fn=)" + " [[ 0.4528, -0.1083, 1.2154],\n", + " [ 1.4329, 1.5554, 1.5001],\n", + " [ 1.0284, 1.4550, 0.5717]]]], grad_fn=)" ] }, "execution_count": 4, @@ -234,31 +243,31 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.0790, 0.0503, -0.0934],\n", - " [-0.1149, -0.1903, -0.1329],\n", - " [-0.1813, 0.0108, 0.0593]],\n", + "QuantTensor(value=tensor([[[[-0.0653, 0.0989, 0.0071],\n", + " [-0.1871, -0.0247, -0.0671],\n", + " [ 0.1642, 0.1624, 0.0053]],\n", "\n", - " [[ 0.0970, -0.0215, -0.0144],\n", - " [ 0.2280, 0.1239, -0.0090],\n", - " [ 0.1957, -0.2011, -0.0108]]],\n", + " [[ 0.1306, -0.0335, 0.1448],\n", + " [ 0.1483, -0.0671, -0.2101],\n", + " [ 0.1713, -0.1465, -0.1448]]],\n", "\n", "\n", - " [[[-0.0018, -0.1957, 0.1993],\n", - " [-0.0359, 0.1778, -0.1400],\n", - " [ 0.0916, 0.1059, 0.2173]],\n", + " [[[-0.1448, 0.0600, -0.1201],\n", + " [ 0.1218, -0.1642, 0.1889],\n", + " [ 0.0618, 0.2101, -0.2242]],\n", "\n", - " [[-0.1670, 0.1939, -0.2191],\n", - " [-0.0215, 0.1688, -0.1383],\n", - " [-0.0449, -0.1185, 0.1742]]],\n", + " [[-0.0600, 0.0530, 0.0335],\n", + " [ 0.1201, 0.1571, 0.1254],\n", + " [ 0.1660, 0.0159, -0.0830]]],\n", "\n", "\n", - " [[[-0.0808, -0.1652, -0.0233],\n", - " [-0.0700, 0.0467, -0.0485],\n", - " [ 0.1059, 0.1418, 0.1077]],\n", + " [[[ 0.0106, 0.1536, 0.1730],\n", + " [ 0.1942, 0.0424, 0.2225],\n", + " [ 0.1324, 0.1907, 0.0441]],\n", "\n", - " [[-0.0593, 0.0108, 0.0036],\n", - " [-0.1508, 0.0808, 0.1616],\n", - " [ 0.0144, -0.0287, -0.1365]]]], grad_fn=), scale=tensor(0.0018, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" + " [[ 0.1942, 0.1236, 0.1889],\n", + " [-0.0124, 0.0742, -0.2048],\n", + " [ 0.1271, -0.1607, -0.1924]]]], grad_fn=), scale=tensor(0.0018, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" ] }, "execution_count": 6, @@ -325,15 +334,15 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "tensor(0.0173, grad_fn=)\n", - "tensor(0.0307, grad_fn=)\n" + "tensor(0.0211, grad_fn=)\n", + "tensor(0.0162, grad_fn=)\n" ] } ], @@ -361,34 +370,31 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.9489, -0.9111, -0.0536, 0.5788, 0.3645],\n", - " [ 0.3401, 1.4325, 0.6498, 0.6411, -1.4390],\n", - " [-1.9029, 0.7012, 0.1591, 1.9235, 0.5883],\n", - " [-2.7258, 2.5330, 0.9165, -0.0820, 3.4148],\n", - " [-0.3651, 1.0164, 0.9567, -0.2758, -1.1376]],\n", - "\n", - " [[-0.2414, 2.2111, -1.9124, -2.3814, -0.8805],\n", - " [ 1.3191, -0.8965, -0.2048, -3.8113, 1.1142],\n", 
- " [-0.3381, -0.2238, 1.2661, 0.0068, 0.2567],\n", - " [ 0.0731, -0.4280, 0.0909, 0.0875, -1.6851],\n", - " [-0.7744, -1.4127, -0.8143, 1.3557, -0.2802]]]],\n", - " grad_fn=), scale=tensor(0.0240, grad_fn=), zero_point=tensor(0.), bit_width=tensor(9.), signed_t=tensor(True), training_t=tensor(True))" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "QuantTensor(value=tensor([[[[-0.1106, 1.1945, -0.4972, -2.0968, 0.7175],\n", + " [-2.5901, 0.0588, -0.2014, 2.1486, 1.6435],\n", + " [ 0.9067, -2.5212, 2.2193, 0.2352, -0.8395],\n", + " [-0.8351, 0.6341, -0.5551, 0.1040, -3.3151],\n", + " [-0.8979, -0.7092, 3.8232, 1.0875, 0.3954]],\n", + "\n", + " [[ 1.4363, -1.3973, 1.3249, 2.6914, 0.3660],\n", + " [ 1.5057, 1.8094, 0.5100, -1.6874, 1.9981],\n", + " [ 1.2472, -1.7813, 0.0334, -1.2880, -2.9333],\n", + " [ 0.0180, -1.4298, -2.9978, 0.5494, -1.4548],\n", + " [ 1.6738, -0.3177, -0.3721, -0.1650, -1.1871]]]],\n", + " grad_fn=), scale=0.018651068210601807, zero_point=0.0, bit_width=9.0, signed_t=True, training_t=True)\n" + ] } ], "source": [ "out_tensor = out_tensor_0 + out_tensor_1\n", - "out_tensor" + "print(out_tensor)" ] }, { @@ -401,7 +407,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -417,23 +423,23 @@ }, { "cell_type": "code", - "execution_count": 108, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[1.5800, 1.0157],\n", - " [1.4445, 0.8577]],\n", + "QuantTensor(value=tensor([[[[0.5191, 0.6402],\n", + " [2.1455, 0.5883]],\n", "\n", - " [[0.5643, 1.2414],\n", - " [1.0383, 0.9028]],\n", + " [[2.0417, 0.5883],\n", + " [1.2631, 0.3980]],\n", "\n", - " [[0.5191, 0.6546],\n", - " [2.1442, 0.5868]]]], grad_fn=), scale=tensor(0.0226, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" + " [[0.7959, 0.5191],\n", + " [0.8132, 1.3496]]]], grad_fn=), scale=tensor(0.0173, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 108, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -455,29 +461,37 @@ }, { "cell_type": "code", - "execution_count": 109, + "execution_count": 13, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_1410693/1377665000.py:1: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. 
(Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n", + " torch.tanh(quant_tensor)\n" + ] + }, { "data": { "text/plain": [ - "tensor([[[[-0.4943, -0.9938, -0.9073, 0.7681],\n", - " [-0.3262, 0.9186, 0.1786, 0.3659],\n", - " [ 0.7489, 0.8946, -0.0451, -0.5594],\n", - " [-0.1346, -0.4943, -0.4770, 0.6951]],\n", + "tensor([[[[ 0.4770, 0.2212, 0.0691, 0.5650],\n", + " [-0.0346, -0.6618, -0.4635, -0.3482],\n", + " [ 0.9730, -0.7245, -0.5881, -0.5287],\n", + " [-0.0863, 0.8857, 0.5287, -0.4498]],\n", "\n", - " [[ 0.0676, 0.5111, 0.4943, 0.8459],\n", - " [-0.8990, -0.9426, 0.0676, -0.7945],\n", - " [-0.9220, 0.0676, -0.5594, 0.6321],\n", - " [-0.0676, 0.7772, 0.7177, -0.4414]],\n", + " [[ 0.9669, 0.5650, -0.6211, -0.4498],\n", + " [-0.2376, 0.6103, 0.5287, 0.2700],\n", + " [-0.6808, 0.8519, 0.2700, -0.5531],\n", + " [-0.0173, 0.8264, 0.3782, -0.1881]],\n", "\n", - " [[ 0.4770, 0.2220, 0.0676, 0.5747],\n", - " [-0.0451, -0.6710, -0.4594, -0.3462],\n", - " [ 0.9729, -0.7177, -0.5896, -0.5276],\n", - " [-0.0900, 0.8852, 0.5276, -0.4414]]]], grad_fn=)" + " [[-0.6211, -0.9764, -0.5993, 0.4770],\n", + " [ 0.5033, 0.6618, -0.1881, -0.6211],\n", + " [-0.8031, 0.1375, 0.5287, 0.8740],\n", + " [-0.6714, 0.6714, -0.5650, 0.8611]]]], grad_fn=)" ] }, - "execution_count": 109, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -497,26 +511,26 @@ }, { "cell_type": "code", - "execution_count": 110, + "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.9693, -0.9431, 0.2459],\n", - " [ 0.5416, 0.9037, -0.5278],\n", - " [-0.6207, -1.3578, -0.4815]],\n", + "QuantTensor(value=tensor([[[[-0.3568, -0.1883, 0.3589],\n", + " [-0.4470, 0.1039, -0.3945],\n", + " [-0.4190, 0.3723, 0.8384]],\n", "\n", - " [[ 0.4551, -1.4065, 0.8889],\n", - " [-0.3393, 0.0803, -0.1748],\n", - " [-0.0977, 0.6284, -0.7193]],\n", + " [[-0.0510, 0.5514, -0.2751],\n", + " [-0.5668, 0.5824, 0.2328],\n", + " [ 0.1316, -0.2518, 1.0418]],\n", "\n", - " [[ 0.3655, 0.7626, -0.2634],\n", - " [-0.3453, 0.3349, 0.1923],\n", - " [ 0.5993, -0.9579, 0.3557]]]], grad_fn=), scale=tensor([[[[3.2208e-05]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" + " [[ 0.2734, 0.7268, -0.0249],\n", + " [-0.1732, 0.5197, 1.1158],\n", + " [ 0.3771, -0.3810, 0.2008]]]], grad_fn=), scale=tensor([[[[3.1958e-05]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 110, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -533,20 +547,9 @@ }, { "cell_type": "code", - "execution_count": 111, + "execution_count": 15, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 111, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "assert out_tensor.is_valid" ] @@ -569,26 +572,26 @@ }, { "cell_type": "code", - "execution_count": 112, + "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 5.7000e-03, 2.5000e-03, -1.2400e-02, -7.2000e-03, 3.7000e-03],\n", - " [-2.3000e-03, 7.0000e-04, -1.2700e-02, 5.2000e-03, 4.0000e-04],\n", - " [-7.9000e-03, 9.5000e-03, 6.6000e-03, 5.4000e-03, 2.5000e-03],\n", - " [ 1.1100e-02, 2.4000e-03, 1.0000e-02, -3.7000e-03, 7.2000e-03],\n", - " [-1.1500e-02, -5.8000e-03, -9.3000e-03, 1.0000e-02, 
3.5000e-03]],\n", + "QuantTensor(value=tensor([[[[ 7.2000e-03, -3.7000e-03, 7.7000e-03, -2.4000e-03, -8.9000e-03],\n", + " [-1.2000e-02, -8.1000e-03, 7.2000e-03, -1.1300e-02, -9.7000e-03],\n", + " [-1.0000e-03, 1.0100e-02, 3.8000e-03, -1.1900e-02, 6.9000e-03],\n", + " [ 8.3000e-03, 1.0000e-04, -6.9000e-03, 3.9000e-03, -5.4000e-03],\n", + " [ 1.1300e-02, -6.0000e-03, 9.7000e-03, 0.0000e+00, 1.0900e-02]],\n", "\n", - " [[-6.8000e-03, 1.1500e-02, -1.0600e-02, -1.5000e-03, -1.9000e-03],\n", - " [ 2.9000e-03, 9.5000e-03, 7.2000e-03, -3.7000e-03, 7.7000e-03],\n", - " [-2.4000e-03, -8.9000e-03, -1.2000e-02, -8.1000e-03, 7.2000e-03],\n", - " [-1.1300e-02, -9.7000e-03, -1.0000e-03, 1.0100e-02, 3.8000e-03],\n", - " [-1.1900e-02, 6.9000e-03, 8.3000e-03, 1.0000e-04, -6.9000e-03]]]]), scale=tensor(1.0000e-04), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" + " [[-1.0900e-02, 1.1400e-02, -6.4000e-03, 9.2000e-03, 7.1000e-03],\n", + " [-6.0000e-04, 9.2000e-03, -8.5000e-03, 5.0000e-03, 6.5000e-03],\n", + " [-8.3000e-03, -1.2000e-03, 7.4000e-03, 9.2000e-03, -6.0000e-04],\n", + " [-2.1000e-03, 9.5000e-03, 3.0000e-04, -2.9000e-03, -6.5000e-03],\n", + " [-1.1800e-02, -4.8000e-03, 5.4000e-03, -2.5000e-03, 9.0000e-04]]]]), scale=tensor(1.0000e-04), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 112, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -613,20 +616,9 @@ }, { "cell_type": "code", - "execution_count": 113, + "execution_count": 17, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 113, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "assert quant_tensor_input.is_valid" ] @@ -642,26 +634,26 @@ }, { "cell_type": "code", - "execution_count": 114, + "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.0085, 0.0066, 0.0050],\n", - " [-0.0038, -0.0009, -0.0115],\n", - " [-0.0055, -0.0037, 0.0009]],\n", + "QuantTensor(value=tensor([[[[-0.0019, 0.0049, -0.0012],\n", + " [-0.0012, 0.0050, -0.0074],\n", + " [-0.0023, -0.0035, -0.0033]],\n", "\n", - " [[ 0.0015, -0.0027, -0.0079],\n", - " [-0.0034, -0.0060, 0.0043],\n", - " [-0.0008, 0.0052, -0.0033]],\n", + " [[-0.0031, 0.0028, 0.0116],\n", + " [ 0.0079, 0.0046, 0.0022],\n", + " [ 0.0021, -0.0004, 0.0011]],\n", "\n", - " [[-0.0015, 0.0082, -0.0038],\n", - " [-0.0021, 0.0004, -0.0054],\n", - " [-0.0021, -0.0079, 0.0013]]]], grad_fn=), scale=tensor([[[[1.8448e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" + " [[-0.0045, -0.0010, 0.0002],\n", + " [-0.0044, 0.0027, 0.0025],\n", + " [-0.0009, 0.0040, -0.0044]]]], grad_fn=), scale=tensor([[[[1.8307e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 114, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -675,20 +667,9 @@ }, { "cell_type": "code", - "execution_count": 115, + "execution_count": 19, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 115, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "assert out_tensor.is_valid" ] @@ -702,26 +683,26 @@ }, { "cell_type": "code", - "execution_count": 116, + "execution_count": 20, "metadata": {}, 
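The notebook cells around here demonstrate the configuration rule this PR starts enforcing: Int8Bias derives its scale from the input and weight scales, so it needs a quantized input, whereas internally scaled bias quantizers do not. A standalone sketch of a well-configured layer, for contrast with the failing cell further below (Int8ActPerTensorFloat is an assumed stock quantizer, not taken from this diff):

import torch
from brevitas.nn import QuantConv2d
from brevitas.quant.scaled_int import Int8ActPerTensorFloat, Int8Bias

# With a quantized input the bias scale is well defined, so
# return_quant_tensor=True is a valid configuration.
conv = QuantConv2d(
    in_channels=2, out_channels=3, kernel_size=(3, 3), bias=True,
    input_quant=Int8ActPerTensorFloat, bias_quant=Int8Bias,
    return_quant_tensor=True)
out = conv(torch.randn(1, 2, 5, 5))  # valid QuantTensor; dropping input_quant would raise
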
"outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.0035, -0.0037, -0.0050],\n", - " [ 0.0010, -0.0051, -0.0027],\n", - " [-0.0010, 0.0047, 0.0017]],\n", + "QuantTensor(value=tensor([[[[-0.0073, 0.0040, -0.0011],\n", + " [-0.0033, 0.0078, -0.0028],\n", + " [ 0.0005, -0.0025, -0.0008]],\n", "\n", - " [[ 0.0021, 0.0002, 0.0027],\n", - " [ 0.0028, 0.0002, -0.0044],\n", - " [ 0.0008, -0.0052, -0.0024]],\n", + " [[ 0.0021, -0.0021, 0.0035],\n", + " [ 0.0012, -0.0016, -0.0023],\n", + " [-0.0010, -0.0015, 0.0040]],\n", "\n", - " [[ 0.0010, -0.0052, -0.0011],\n", - " [-0.0018, 0.0024, 0.0011],\n", - " [-0.0001, 0.0039, 0.0035]]]], grad_fn=), scale=tensor([[[[1.7410e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" + " [[-0.0010, 0.0047, 0.0025],\n", + " [-0.0014, 0.0021, -0.0039],\n", + " [ 0.0036, -0.0003, 0.0026]]]], grad_fn=), scale=tensor([[[[1.7393e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 116, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -741,26 +722,26 @@ }, { "cell_type": "code", - "execution_count": 117, + "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.2111, 0.4060, 0.3654],\n", - " [-0.7876, 0.8119, -0.9825],\n", - " [-0.5115, 0.3979, -0.3248]],\n", + "QuantTensor(value=tensor([[[[-0.2117, -0.4811, 0.0385],\n", + " [-0.5100, -0.2502, -0.2213],\n", + " [-0.5773, 0.0192, -0.5485]],\n", "\n", - " [[ 0.3816, 0.0568, -0.0812],\n", - " [ 1.0312, -0.7876, 0.8038],\n", - " [-0.3491, -0.4141, 0.0650]],\n", + " [[ 0.1347, 0.8179, -1.2316],\n", + " [-0.6062, 0.4426, -0.3849],\n", + " [ 0.1732, -0.5100, -0.1251]],\n", "\n", - " [[-0.5846, -0.4222, -0.0731],\n", - " [-0.7389, 0.5034, -0.2517],\n", - " [-0.1624, -0.4385, 0.7308]]]], grad_fn=), scale=tensor(0.0081, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" + " [[ 1.0873, 0.2406, -0.2887],\n", + " [-0.4330, -0.4907, -0.2021],\n", + " [ 0.6447, 0.4811, 0.1347]]]], grad_fn=), scale=tensor(0.0096, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 117, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -777,20 +758,9 @@ }, { "cell_type": "code", - "execution_count": 118, + "execution_count": 22, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 118, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "assert out_tensor.is_valid" ] @@ -816,7 +786,7 @@ }, { "cell_type": "code", - "execution_count": 119, + "execution_count": 23, "metadata": { "tags": [ "raises-exception" @@ -825,18 +795,17 @@ "outputs": [ { "ename": "RuntimeError", - "evalue": "Input scale required", + "evalue": "QuantLayer is not correctly configured", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_48365/2280634207.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0min_channels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mout_channels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkernel_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m bias_quant=Int8Bias, return_quant_tensor=True)\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mbias_quant_conv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/opt/conda/envs/torch_1.10/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1103\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/workspace/scratch/git/fork_brevitas/src/brevitas/nn/quant_conv.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 192\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 193\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 194\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0minner_forward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_weight\u001b[0m\u001b[0;34m:\u001b[0m 
\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_bias\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/workspace/scratch/git/fork_brevitas/src/brevitas/nn/quant_layer.py\u001b[0m in \u001b[0;36mforward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 330\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 331\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 332\u001b[0;31m \u001b[0mquant_bias\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias_quant\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_scale\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_bit_width\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 333\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_inference_quant_bias\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 334\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cached_bias\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_CachedIO\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquant_bias\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata_only\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/envs/torch_1.10/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1103\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/workspace/scratch/git/fork_brevitas/src/brevitas/proxy/parameter_quant.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, input_scale, input_bit_width)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0mimpl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport_handler\u001b[0m \u001b[0;32mif\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport_mode\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor_quant\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_input_scale\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0minput_scale\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Input scale required\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_input_bit_width\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0minput_bit_width\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Input bit-width required\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mRuntimeError\u001b[0m: Input scale required" + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb Cell 46\u001b[0m line \u001b[0;36m6\n\u001b[1;32m 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mbrevitas\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mquant\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mscaled_int\u001b[39;00m \u001b[39mimport\u001b[39;00m Int8Bias\n\u001b[1;32m 3\u001b[0m bias_quant_conv \u001b[39m=\u001b[39m QuantConv2d(\n\u001b[1;32m 4\u001b[0m in_channels\u001b[39m=\u001b[39m\u001b[39m2\u001b[39m, out_channels\u001b[39m=\u001b[39m\u001b[39m3\u001b[39m, kernel_size\u001b[39m=\u001b[39m(\u001b[39m3\u001b[39m,\u001b[39m3\u001b[39m), bias\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m,\n\u001b[1;32m 5\u001b[0m bias_quant\u001b[39m=\u001b[39mInt8Bias, return_quant_tensor\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[0;32m----> 6\u001b[0m bias_quant_conv(torch\u001b[39m.\u001b[39;49mrandn(\u001b[39m1\u001b[39;49m, \u001b[39m2\u001b[39;49m, \u001b[39m5\u001b[39;49m, \u001b[39m5\u001b[39;49m))\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compiled_call_impl(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs) \u001b[39m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_impl(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[39m# this function, and just call 
forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1529\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_conv.py:198\u001b[0m, in \u001b[0;36mQuantConv2d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m: Union[Tensor, QuantTensor]) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Union[Tensor, QuantTensor]:\n\u001b[0;32m--> 198\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mforward_impl(\u001b[39minput\u001b[39;49m)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py:328\u001b[0m, in \u001b[0;36mQuantWeightBiasInputOutputLayer.forward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 324\u001b[0m compute_output_quant_tensor \u001b[39m=\u001b[39m \u001b[39misinstance\u001b[39m(quant_input, QuantTensor) \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(\n\u001b[1;32m 325\u001b[0m quant_weight, QuantTensor)\n\u001b[1;32m 326\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (compute_output_quant_tensor \u001b[39mor\u001b[39;00m\n\u001b[1;32m 327\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_output_quant_enabled) \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreturn_quant_tensor:\n\u001b[0;32m--> 328\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mQuantLayer is not correctly configured\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 330\u001b[0m \u001b[39mif\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreturn_quant_tensor \u001b[39mor\u001b[39;00m\n\u001b[1;32m 331\u001b[0m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_bias_quant_enabled \u001b[39mand\u001b[39;00m\n\u001b[1;32m 332\u001b[0m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias_quant\u001b[39m.\u001b[39mrequires_input_scale \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias_quant\u001b[39m.\u001b[39mrequires_input_bit_width))):\n\u001b[1;32m 333\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(quant_input, QuantTensor) \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(quant_weight, QuantTensor):\n", + "\u001b[0;31mRuntimeError\u001b[0m: QuantLayer is not correctly configured" ] } ], @@ -858,26 +827,27 @@ }, { "cell_type": "code", - "execution_count": 120, + "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.0005, 0.0043, -0.0004],\n", - " [ 0.0005, 0.0106, 
0.0012],\n", - " [ 0.0021, 0.0007, -0.0050]],\n", + "QuantTensor(value=tensor([[[[-2.4238e-03, -5.6598e-03, 5.1882e-03],\n", + " [-6.5582e-03, 8.9274e-03, 4.9640e-04],\n", + " [ 9.6283e-03, -1.7466e-03, -4.8311e-03]],\n", "\n", - " [[-0.0067, -0.0035, -0.0059],\n", - " [-0.0050, -0.0015, -0.0039],\n", - " [ 0.0015, 0.0028, -0.0008]],\n", + " [[ 2.9322e-03, -3.1358e-03, -6.2727e-04],\n", + " [ 2.8722e-06, -3.7981e-03, 1.0973e-02],\n", + " [-4.1031e-03, 6.5909e-03, -4.2369e-03]],\n", "\n", - " [[-0.0051, -0.0050, 0.0060],\n", - " [-0.0015, 0.0037, 0.0071],\n", - " [ 0.0067, 0.0035, -0.0071]]]], grad_fn=), scale=tensor([[[[1.8108e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" + " [[ 4.1967e-03, -7.0733e-03, 1.6456e-03],\n", + " [ 1.8197e-03, -3.1683e-03, 4.8200e-03],\n", + " [-3.2585e-04, 3.1055e-03, 1.9703e-03]]]],\n", + " grad_fn=), scale=tensor([[[[1.7953e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 120, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -895,26 +865,26 @@ }, { "cell_type": "code", - "execution_count": 121, + "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.3825, 0.1371, 0.9135],\n", - " [-0.2016, 0.7495, -0.4071],\n", - " [-0.0755, 0.5283, 0.2388]],\n", + "QuantTensor(value=tensor([[[[-0.2816, -0.5271, -0.1748],\n", + " [-0.4247, -0.1575, 0.0681],\n", + " [ 0.6528, -0.5346, -0.0657]],\n", "\n", - " [[ 0.0788, -0.3802, -0.2234],\n", - " [ 0.8678, -0.5546, 0.4408],\n", - " [-0.6788, 0.4422, 0.3007]],\n", + " [[ 0.2993, -0.3383, 0.3035],\n", + " [-0.4595, -0.6796, -0.9720],\n", + " [-0.1948, -0.5169, -0.2175]],\n", "\n", - " [[ 0.4412, -0.3205, 1.0033],\n", - " [-0.0083, -0.3295, -0.2076],\n", - " [ 0.4417, -0.1046, -0.3493]]]], grad_fn=), scale=tensor([[[[3.8610e-05]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" + " [[ 0.5586, 0.0665, -0.5807],\n", + " [ 0.5565, 0.1780, -0.0555],\n", + " [-0.1080, 0.0791, -0.2262]]]], grad_fn=), scale=tensor([[[[4.2009e-05]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 121, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } @@ -928,26 +898,26 @@ }, { "cell_type": "code", - "execution_count": 122, + "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.0036, 0.0024, -0.0033],\n", - " [ 0.0050, 0.0080, -0.0014],\n", - " [-0.0036, -0.0080, -0.0029]],\n", + "QuantTensor(value=tensor([[[[-0.0058, 0.0030, 0.0030],\n", + " [-0.0013, -0.0002, 0.0043],\n", + " [-0.0061, 0.0033, -0.0001]],\n", "\n", - " [[ 0.0083, -0.0093, 0.0048],\n", - " [ 0.0035, 0.0015, -0.0011],\n", - " [-0.0003, 0.0067, 0.0013]],\n", + " [[ 0.0013, -0.0008, -0.0015],\n", + " [ 0.0011, 0.0012, -0.0012],\n", + " [-0.0013, -0.0020, 0.0002]],\n", "\n", - " [[-0.0009, -0.0019, 0.0039],\n", - " [ 0.0010, 0.0056, -0.0037],\n", - " [ 0.0091, -0.0095, 0.0054]]]], grad_fn=), scale=tensor([[[[1.8384e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" + " [[-0.0061, 0.0053, -0.0004],\n", + " [ 0.0028, 0.0031, -0.0037],\n", + " [ 0.0027, -0.0048, -0.0044]]]], grad_fn=), scale=tensor([[[[1.7370e-07]]]], grad_fn=), zero_point=tensor(0.), 
bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 122, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -967,7 +937,7 @@ }, { "cell_type": "code", - "execution_count": 123, + "execution_count": 27, "metadata": { "tags": [ "raises-exception" @@ -981,12 +951,14 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_48365/2990591641.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0min_channels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_channels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkernel_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m output_quant=Int8ActPerTensorFloat, bias_quant=Int8Bias, return_quant_tensor=True)\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0moutput_bias_quant_conv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/opt/conda/envs/torch_1.10/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1103\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/workspace/scratch/git/fork_brevitas/src/brevitas/nn/quant_conv.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m 
\u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 192\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 193\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 194\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0minner_forward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_weight\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_bias\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/workspace/scratch/git/fork_brevitas/src/brevitas/nn/quant_layer.py\u001b[0m in \u001b[0;36mforward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 330\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 331\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 332\u001b[0;31m \u001b[0mquant_bias\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias_quant\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_scale\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_bit_width\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 333\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_inference_quant_bias\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 334\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cached_bias\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_CachedIO\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquant_bias\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata_only\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/envs/torch_1.10/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1103\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/workspace/scratch/git/fork_brevitas/src/brevitas/proxy/parameter_quant.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, input_scale, input_bit_width)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0mimpl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport_handler\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport_mode\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor_quant\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_input_scale\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0minput_scale\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Input scale required\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_input_bit_width\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0minput_bit_width\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Input bit-width required\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb Cell 53\u001b[0m line \u001b[0;36m4\n\u001b[1;32m 1\u001b[0m output_bias_quant_conv \u001b[39m=\u001b[39m QuantConv2d(\n\u001b[1;32m 2\u001b[0m in_channels\u001b[39m=\u001b[39m\u001b[39m2\u001b[39m, out_channels\u001b[39m=\u001b[39m\u001b[39m3\u001b[39m, kernel_size\u001b[39m=\u001b[39m(\u001b[39m3\u001b[39m,\u001b[39m3\u001b[39m), bias\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m,\n\u001b[1;32m 3\u001b[0m output_quant\u001b[39m=\u001b[39mInt8ActPerTensorFloat, bias_quant\u001b[39m=\u001b[39mInt8Bias, return_quant_tensor\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[0;32m----> 4\u001b[0m output_bias_quant_conv(torch\u001b[39m.\u001b[39;49mrandn(\u001b[39m1\u001b[39;49m, \u001b[39m2\u001b[39;49m, \u001b[39m5\u001b[39;49m, \u001b[39m5\u001b[39;49m))\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compiled_call_impl(\u001b[39m*\u001b[39margs, 
\u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs) \u001b[39m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_impl(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1529\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_conv.py:198\u001b[0m, in \u001b[0;36mQuantConv2d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m: Union[Tensor, QuantTensor]) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Union[Tensor, QuantTensor]:\n\u001b[0;32m--> 198\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mforward_impl(\u001b[39minput\u001b[39;49m)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py:347\u001b[0m, in \u001b[0;36mQuantWeightBiasInputOutputLayer.forward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 344\u001b[0m output_signed \u001b[39m=\u001b[39m quant_input_signed \u001b[39mor\u001b[39;00m quant_weight_signed\n\u001b[1;32m 346\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 347\u001b[0m quant_bias \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbias_quant(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbias, output_scale, output_bit_width)\n\u001b[1;32m 348\u001b[0m quant_bias_value \u001b[39m=\u001b[39m \u001b[39mgetattr\u001b[39m(quant_bias, \u001b[39m'\u001b[39m\u001b[39mvalue\u001b[39m\u001b[39m'\u001b[39m, quant_bias)\n\u001b[1;32m 349\u001b[0m quant_bias_scale \u001b[39m=\u001b[39m \u001b[39mgetattr\u001b[39m(quant_bias, \u001b[39m'\u001b[39m\u001b[39mscale\u001b[39m\u001b[39m'\u001b[39m, \u001b[39mNone\u001b[39;00m)\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in 
\u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compiled_call_impl(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs) \u001b[39m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_impl(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1529\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/parameter_quant.py:206\u001b[0m, in \u001b[0;36mBiasQuantProxyFromInjector.forward\u001b[0;34m(self, x, input_scale, input_bit_width)\u001b[0m\n\u001b[1;32m 204\u001b[0m impl \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mexport_handler \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mexport_mode \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtensor_quant\n\u001b[1;32m 205\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mrequires_input_scale \u001b[39mand\u001b[39;00m input_scale \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 206\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mInput scale required\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 207\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mrequires_input_bit_width \u001b[39mand\u001b[39;00m input_bit_width \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 208\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mInput bit-width required\u001b[39m\u001b[39m\"\u001b[39m)\n", "\u001b[0;31mRuntimeError\u001b[0m: Input scale required" ] } @@ -1007,26 +979,26 @@ }, { "cell_type": "code", - "execution_count": 124, + "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[[ 0.2152, 0.8346, 0.0746],\n", - " [-0.0738, -0.5212, 
0.1019],\n", - " [-0.6004, 0.1500, -0.1453]],\n", + "tensor([[[[-0.4360, -0.2674, -0.4194],\n", + " [-0.2412, -0.6360, -0.6838],\n", + " [-0.5227, -0.0199, -0.1445]],\n", "\n", - " [[-1.1551, -1.3458, -0.1312],\n", - " [ 0.2502, -0.5267, 0.2412],\n", - " [-0.3556, -0.3289, -0.2276]],\n", + " [[-0.3524, 0.8025, 0.2844],\n", + " [ 0.9945, -0.4782, 0.8064],\n", + " [ 0.5732, 0.1249, 0.3110]],\n", "\n", - " [[-0.4599, -0.6094, 0.4682],\n", - " [-0.5064, -0.6768, -0.6638],\n", - " [ 0.0066, -0.3581, 0.2359]]]], grad_fn=)" + " [[ 0.3223, 0.2530, 0.2753],\n", + " [ 0.5764, -0.2533, -0.0181],\n", + " [-0.4147, 0.2049, -0.9944]]]], grad_fn=)" ] }, - "execution_count": 124, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -1051,30 +1023,30 @@ }, { "cell_type": "code", - "execution_count": 125, + "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.6879, -0.6632, -0.2411],\n", - " [ 0.2064, -0.7371, 0.3910],\n", - " [ 0.9533, 0.2994, 0.6546]],\n", + "QuantTensor(value=tensor([[[[-0.6912, 0.0086, 0.1628],\n", + " [-0.4786, -0.8073, 0.5224],\n", + " [ 0.4157, 0.4686, 0.2560]],\n", "\n", - " [[-0.4684, -0.4495, -0.5021],\n", - " [ 0.5738, 0.4199, -0.3380],\n", - " [ 0.6218, -0.0408, -0.8483]],\n", + " [[ 0.3170, -0.5486, -0.5216],\n", + " [ 0.1832, 1.0217, -0.3637],\n", + " [-0.1115, 0.6974, -0.0452]],\n", "\n", - " [[-0.5625, 0.1837, -1.0575],\n", - " [-1.2816, -0.4993, -0.3409],\n", - " [ 0.4556, -1.4269, 0.5369]]]], grad_fn=), scale=tensor([[[[3.0975e-05]]]], grad_fn=), zero_point=tensor([[[[ 1276.0774]],\n", + " [[-0.6168, -0.5241, -0.6593],\n", + " [ 0.6408, 0.2674, 0.4537],\n", + " [-0.3744, -0.7771, -0.2848]]]], grad_fn=), scale=tensor([[[[3.0094e-05]]]], grad_fn=), zero_point=tensor([[[[ 339.3404]],\n", "\n", - " [[-3152.4585]],\n", + " [[-4597.1797]],\n", "\n", - " [[ 7320.2324]]]], grad_fn=), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" + " [[-3452.3711]]]], grad_fn=), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 125, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -1089,20 +1061,9 @@ }, { "cell_type": "code", - "execution_count": 126, + "execution_count": 30, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 126, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "assert out_tensor.is_valid" ] @@ -1116,26 +1077,26 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[[ 0.8357, 0.0733, 0.9527],\n", - " [ 0.1803, 0.2154, 0.7598],\n", - " [ 1.1121, -0.8728, 1.0039]],\n", + "tensor([[[[-0.2327, 0.9267, 0.6294],\n", + " [ 0.0901, 0.1027, -0.0727],\n", + " [-0.5614, 0.6182, 0.5394]],\n", "\n", - " [[ 0.7917, 1.0063, 0.6516],\n", - " [-0.1852, -0.7263, 0.0956],\n", - " [-0.1876, 0.2747, -0.1617]],\n", + " [[ 0.4179, -0.5184, -0.2016],\n", + " [ 0.1390, -0.3925, -0.6171],\n", + " [ 0.4782, 0.0814, 0.6124]],\n", "\n", - " [[ 0.8299, 0.9934, -0.3821],\n", - " [ 0.4865, 0.9309, -0.7924],\n", - " [-0.4201, 0.2343, 0.1532]]]], grad_fn=)" + " [[ 0.2896, -0.3779, 0.9408],\n", + " [-0.1334, 0.6186, 0.2167],\n", + " [-0.5926, 0.3690, -0.0284]]]], grad_fn=)" ] }, - "execution_count": 127, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" } @@ -1153,6 +1114,11 @@ "source": [ "Altough not obvious, the 
output is actually implicitly quantized." ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] } ], "metadata": { @@ -1171,7 +1137,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.15" + "version": "3.11.5" }, "vscode": { "interpreter": { diff --git a/notebooks/02_quant_activation_overview.ipynb b/notebooks/02_quant_activation_overview.ipynb index 4d2ac73d1..388f43cea 100644 --- a/notebooks/02_quant_activation_overview.ipynb +++ b/notebooks/02_quant_activation_overview.ipynb @@ -26,14 +26,12 @@ }, "outputs": [ { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/_tensor.py:1362: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/c10/core/TensorImpl.h:1900.)\n", + " return super().rename(names)\n" + ] } ], "source": [ @@ -68,18 +66,7 @@ "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "torch.manual_seed(0)\n", "input_output_quant_conv = QuantConv2d(\n", @@ -178,18 +165,7 @@ "cell_type": "code", "execution_count": 5, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "assert out_tensor.is_valid" ] @@ -220,7 +196,7 @@ " [-0.7990, -1.2936, -0.7419, -1.3127, -0.2283],\n", " [-2.4351, -0.0761, 0.2283, 0.7990, -0.1902],\n", " [-0.3615, -1.2175, -0.6278, -0.4566, 1.9214]]]],\n", - " grad_fn=)" + " grad_fn=)" ] }, "execution_count": 6, @@ -252,7 +228,7 @@ " [-0.7990, -1.2936, -0.7419, -1.3127, -0.2283],\n", " [-2.4351, -0.0761, 0.2283, 0.7990, -0.1902],\n", " [-0.3615, -1.2175, -0.6278, -0.4566, 1.9214]]]],\n", - " grad_fn=), scale=tensor(0.0190, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" + " grad_fn=), scale=tensor(0.0190, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" ] }, "execution_count": 7, @@ -336,13 +312,13 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=(tensor([[[[0.3878, 0.3611, 0.3655, 0.6433, 0.8236],\n", + "tensor([[[[0.3878, 0.3611, 0.3655, 0.6433, 0.8236],\n", " [0.6257, 0.3567, 0.3611, 0.5474, 0.4810],\n", " [0.3788, 0.1820, 0.4526, 0.6077, 0.7911],\n", " [0.1630, 0.8883, 0.8471, 0.9151, 0.2456],\n", @@ -353,10 +329,10 @@ " [0.3102, 0.2152, 0.3226, 0.2120, 0.4432],\n", " [0.0805, 0.4810, 0.5568, 0.6898, 0.4526],\n", " [0.4106, 0.2284, 0.3480, 0.3878, 0.8723]]]],\n", - " grad_fn=), None, None, None), scale=None, zero_point=None, bit_width=None, signed_t=None, training_t=tensor(True))" + " grad_fn=)" ] }, - "execution_count": 10, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -371,22 +347,14 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": {}, - "outputs": [ - { - "data": { 
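
[Reviewer note on the notebook 01 hunks above] Two nits this patch could pick up while it is already touching 01_quant_tensor_quant_conv2d_overview.ipynb: the final markdown cell misspells "Altough" (should be "Although"), and the extra markdown cell appended at the end of the notebook has an empty "source" list and looks unintentional. Also, since that cell states the float-looking output is implicitly quantized, a short cell making the hidden metadata visible might help readers; the sketch below is a reviewer suggestion, not part of this patch, and only reuses names (QuantConv2d, Int8ActPerTensorFloat) that the notebooks already import.

import torch
from brevitas.nn import QuantConv2d
from brevitas.quant.scaled_int import Int8ActPerTensorFloat

torch.manual_seed(0)
# Same layer shape as the notebook example; bias left out so the sketch
# does not depend on any bias quantizer and its input-scale requirements.
conv = QuantConv2d(
    in_channels=2, out_channels=3, kernel_size=(3, 3), bias=False,
    output_quant=Int8ActPerTensorFloat, return_quant_tensor=True)
out = conv(torch.randn(1, 2, 5, 5))
# The scale/zero_point/bit_width that the float-looking output carries
# implicitly are now explicit on the returned QuantTensor.
print(out.scale, out.zero_point, out.bit_width)
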
- "text/plain": [ - "False" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "assert not sigmoid_out_tensor.is_valid" + "from brevitas.quant_tensor import QuantTensor\n", + "\n", + "\n", + "assert not isinstance(sigmoid_out_tensor, QuantTensor)" ] }, { @@ -400,7 +368,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -416,10 +384,10 @@ " [0.6421, 0.0000, 0.0000, 1.1708, 0.4343],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.2266, 0.7931, 0.0000],\n", - " [0.0000, 0.0000, 0.0000, 0.0000, 1.9262]]]], grad_fn=), scale=tensor(0.0189, grad_fn=), zero_point=tensor(129., grad_fn=), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))" + " [0.0000, 0.0000, 0.0000, 0.0000, 1.9262]]]], grad_fn=), scale=tensor(0.0189, grad_fn=), zero_point=tensor(129., grad_fn=), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))" ] }, - "execution_count": 12, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -442,7 +410,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -461,7 +429,7 @@ " [0.0000, 0.0000, 0.4907]]]], grad_fn=)" ] }, - "execution_count": 13, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -482,7 +450,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -501,7 +469,7 @@ " [0.0000, 0.0000, 0.4839]]]], grad_fn=)" ] }, - "execution_count": 14, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -535,7 +503,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -553,20 +521,9 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 18, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "out1_train = quant_identity(inp1)\n", "out2_train = quant_identity(inp2)\n", @@ -575,20 +532,9 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 19, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "quant_identity.eval()\n", "out1_eval = quant_identity(inp1)\n", @@ -605,7 +551,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 20, "metadata": { "tags": [ "raises-exception" @@ -617,19 +563,19 @@ "evalue": "'Int8ActPerTensorFloatMinMaxInit' can not resolve attribute 'max_val' while building 'scaling_init_impl'", "output_type": "error", "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mDependencyError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0mbrevitas\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnn\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mQuantHardTanh\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 3\u001b[1;33m 
\u001b[0mQuantHardTanh\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\quant_activation.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, act_quant, input_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[0;32m 117\u001b[0m \u001b[0mact_quant\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mact_quant\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 118\u001b[0m \u001b[0mreturn_quant_tensor\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mreturn_quant_tensor\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 119\u001b[1;33m **kwargs)\n\u001b[0m\u001b[0;32m 120\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 121\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\quant_layer.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, act_impl, passthrough_act, input_quant, act_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[0;32m 77\u001b[0m \u001b[0mpassthrough_act\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 78\u001b[0m \u001b[0mact_quant\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 79\u001b[1;33m **kwargs)\n\u001b[0m\u001b[0;32m 80\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 81\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\mixin\\act.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, act_impl, passthrough_act, act_quant, **kwargs)\u001b[0m\n\u001b[0;32m 157\u001b[0m \u001b[0mproxy_prefix\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'act_'\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 158\u001b[0m \u001b[0mkwargs_prefix\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m''\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 159\u001b[1;33m **kwargs)\n\u001b[0m\u001b[0;32m 160\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 161\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\mixin\\base.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, quant, proxy_protocol, none_quant_injector, proxy_prefix, kwargs_prefix, **kwargs)\u001b[0m\n\u001b[0;32m 98\u001b[0m \u001b[0mquant_injector\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 99\u001b[0m \u001b[0mquant_injector\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlet\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[0mfilter_kwargs\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mkwargs_prefix\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 100\u001b[1;33m \u001b[0mquant\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mproxy_class\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 101\u001b[0m 
\u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 102\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mproxy_protocol\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\runtime_quant.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, quant_layer, quant_injector)\u001b[0m\n\u001b[0;32m 108\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 109\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_layer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 110\u001b[1;33m \u001b[0msuper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mActQuantProxyFromInjector\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_layer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 111\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_passthrough_act\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_is_passthrough_act\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 112\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\quant_proxy.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, quant_layer, quant_injector, export_mode, export_handler)\u001b[0m\n\u001b[0;32m 74\u001b[0m \u001b[1;31m# Use a normal list and not a ModuleList since this is a pointer to parent modules\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 75\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtracked_module_list\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 76\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0madd_tracked_module\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_layer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 77\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexport_handler\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mexport_handler\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 78\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexport_mode\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mexport_mode\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\quant_proxy.py\u001b[0m in \u001b[0;36madd_tracked_module\u001b[1;34m(self, module)\u001b[0m\n\u001b[0;32m 130\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtracked_module_list\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 131\u001b[0m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mupdate_tracked_modules\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 132\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minit_tensor_quant\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 133\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 134\u001b[0m \u001b[1;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Trying to add None as a parent module.\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\runtime_quant.py\u001b[0m in \u001b[0;36minit_tensor_quant\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 120\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 121\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0minit_tensor_quant\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 122\u001b[1;33m \u001b[0mtensor_quant\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtensor_quant\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 123\u001b[0m \u001b[0mact_impl\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mact_impl\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 124\u001b[0m \u001b[0mis_act_enabled\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_is_act_enabled\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mact_impl\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtensor_quant\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - " \u001b[1;31m[... 
skipping hidden 1 frame]\u001b[0m\n", - "\u001b[1;31mDependencyError\u001b[0m: 'Int8ActPerTensorFloatMinMaxInit' can not resolve attribute 'max_val' while building 'scaling_init_impl'" + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mDependencyError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/02_quant_activation_overview.ipynb Cell 35\u001b[0m line \u001b[0;36m3\n\u001b[1;32m 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mbrevitas\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mnn\u001b[39;00m \u001b[39mimport\u001b[39;00m QuantHardTanh\n\u001b[0;32m----> 3\u001b[0m QuantHardTanh()\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_activation.py:96\u001b[0m, in \u001b[0;36mQuantHardTanh.__init__\u001b[0;34m(self, act_quant, input_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\n\u001b[1;32m 91\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 92\u001b[0m act_quant: Optional[ActQuantType] \u001b[39m=\u001b[39m Int8ActPerTensorFloatMinMaxInit,\n\u001b[1;32m 93\u001b[0m input_quant: Optional[ActQuantType] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 94\u001b[0m return_quant_tensor: \u001b[39mbool\u001b[39m \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m,\n\u001b[1;32m 95\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[0;32m---> 96\u001b[0m QuantNLAL\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\n\u001b[1;32m 97\u001b[0m \u001b[39mself\u001b[39;49m,\n\u001b[1;32m 98\u001b[0m act_impl\u001b[39m=\u001b[39;49mnn\u001b[39m.\u001b[39;49mHardtanh,\n\u001b[1;32m 99\u001b[0m passthrough_act\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m 100\u001b[0m input_quant\u001b[39m=\u001b[39;49minput_quant,\n\u001b[1;32m 101\u001b[0m act_quant\u001b[39m=\u001b[39;49mact_quant,\n\u001b[1;32m 102\u001b[0m return_quant_tensor\u001b[39m=\u001b[39;49mreturn_quant_tensor,\n\u001b[1;32m 103\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py:40\u001b[0m, in \u001b[0;36mQuantNonLinearActLayer.__init__\u001b[0;34m(self, act_impl, passthrough_act, input_quant, act_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[1;32m 38\u001b[0m QuantLayerMixin\u001b[39m.\u001b[39m\u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, return_quant_tensor)\n\u001b[1;32m 39\u001b[0m QuantInputMixin\u001b[39m.\u001b[39m\u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, input_quant, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[0;32m---> 40\u001b[0m QuantNonLinearActMixin\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\u001b[39mself\u001b[39;49m, act_impl, passthrough_act, act_quant, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/mixin/act.py:118\u001b[0m, in \u001b[0;36mQuantNonLinearActMixin.__init__\u001b[0;34m(self, act_impl, passthrough_act, act_quant, act_proxy_prefix, act_kwargs_prefix, **kwargs)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\n\u001b[1;32m 108\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 109\u001b[0m act_impl: Optional[Type[Module]],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 113\u001b[0m 
act_kwargs_prefix\u001b[39m=\u001b[39m\u001b[39m'\u001b[39m\u001b[39m'\u001b[39m,\n\u001b[1;32m 114\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m prefixed_kwargs \u001b[39m=\u001b[39m {\n\u001b[1;32m 116\u001b[0m act_kwargs_prefix \u001b[39m+\u001b[39m \u001b[39m'\u001b[39m\u001b[39mact_impl\u001b[39m\u001b[39m'\u001b[39m: act_impl,\n\u001b[1;32m 117\u001b[0m act_kwargs_prefix \u001b[39m+\u001b[39m \u001b[39m'\u001b[39m\u001b[39mpassthrough_act\u001b[39m\u001b[39m'\u001b[39m: passthrough_act}\n\u001b[0;32m--> 118\u001b[0m QuantProxyMixin\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\n\u001b[1;32m 119\u001b[0m \u001b[39mself\u001b[39;49m,\n\u001b[1;32m 120\u001b[0m quant\u001b[39m=\u001b[39;49mact_quant,\n\u001b[1;32m 121\u001b[0m proxy_prefix\u001b[39m=\u001b[39;49mact_proxy_prefix,\n\u001b[1;32m 122\u001b[0m kwargs_prefix\u001b[39m=\u001b[39;49mact_kwargs_prefix,\n\u001b[1;32m 123\u001b[0m proxy_protocol\u001b[39m=\u001b[39;49mActQuantProxyProtocol,\n\u001b[1;32m 124\u001b[0m none_quant_injector\u001b[39m=\u001b[39;49mNoneActQuant,\n\u001b[1;32m 125\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mprefixed_kwargs,\n\u001b[1;32m 126\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/mixin/base.py:71\u001b[0m, in \u001b[0;36mQuantProxyMixin.__init__\u001b[0;34m(self, quant, proxy_protocol, none_quant_injector, proxy_prefix, kwargs_prefix, **kwargs)\u001b[0m\n\u001b[1;32m 69\u001b[0m quant_injector \u001b[39m=\u001b[39m quant\n\u001b[1;32m 70\u001b[0m quant_injector \u001b[39m=\u001b[39m quant_injector\u001b[39m.\u001b[39mlet(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39mfilter_kwargs(kwargs_prefix, kwargs))\n\u001b[0;32m---> 71\u001b[0m quant \u001b[39m=\u001b[39m quant_injector\u001b[39m.\u001b[39;49mproxy_class(\u001b[39mself\u001b[39;49m, quant_injector)\n\u001b[1;32m 72\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 73\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(quant, proxy_protocol):\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/runtime_quant.py:89\u001b[0m, in \u001b[0;36mActQuantProxyFromInjector.__init__\u001b[0;34m(self, quant_layer, quant_injector)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, quant_layer, quant_injector):\n\u001b[0;32m---> 89\u001b[0m QuantProxyFromInjector\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\u001b[39mself\u001b[39;49m, quant_layer, quant_injector)\n\u001b[1;32m 90\u001b[0m ActQuantProxyProtocol\u001b[39m.\u001b[39m\u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m)\n\u001b[1;32m 91\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_passthrough_act \u001b[39m=\u001b[39m _is_passthrough_act(quant_injector)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/quant_proxy.py:82\u001b[0m, in \u001b[0;36mQuantProxyFromInjector.__init__\u001b[0;34m(self, quant_layer, quant_injector)\u001b[0m\n\u001b[1;32m 80\u001b[0m \u001b[39m# Use a normal list and not a ModuleList since this is a pointer to parent modules\u001b[39;00m\n\u001b[1;32m 81\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtracked_module_list \u001b[39m=\u001b[39m []\n\u001b[0;32m---> 82\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49madd_tracked_module(quant_layer)\n\u001b[1;32m 83\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdisable_quant 
\u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/quant_proxy.py:120\u001b[0m, in \u001b[0;36mQuantProxyFromInjector.add_tracked_module\u001b[0;34m(self, module)\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtracked_module_list\u001b[39m.\u001b[39mappend(module)\n\u001b[1;32m 119\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mupdate_tracked_modules()\n\u001b[0;32m--> 120\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49minit_tensor_quant()\n\u001b[1;32m 121\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 122\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mTrying to add None as a parent module.\u001b[39m\u001b[39m\"\u001b[39m)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/runtime_quant.py:102\u001b[0m, in \u001b[0;36mActQuantProxyFromInjector.init_tensor_quant\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39minit_tensor_quant\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[0;32m--> 102\u001b[0m tensor_quant \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mquant_injector\u001b[39m.\u001b[39;49mtensor_quant\n\u001b[1;32m 103\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m'\u001b[39m\u001b[39mact_impl\u001b[39m\u001b[39m'\u001b[39m \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mquant_injector:\n\u001b[1;32m 104\u001b[0m act_impl \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mquant_injector\u001b[39m.\u001b[39mact_impl\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/inject/__init__.py:129\u001b[0m, in \u001b[0;36m_ExtendedInjectorType.__getattr__\u001b[0;34m(cls, attrname)\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 127\u001b[0m message \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39m{!r}\u001b[39;00m\u001b[39m can not resolve attribute \u001b[39m\u001b[39m{!r}\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 128\u001b[0m \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, current_attr)\n\u001b[0;32m--> 129\u001b[0m \u001b[39mraise\u001b[39;00m DependencyError(message)\n\u001b[1;32m 131\u001b[0m marker, attribute, args, have_defaults \u001b[39m=\u001b[39m spec\n\u001b[1;32m 133\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mset\u001b[39m(args)\u001b[39m.\u001b[39missubset(cached):\n", + "\u001b[0;31mDependencyError\u001b[0m: 'Int8ActPerTensorFloatMinMaxInit' can not resolve attribute 'max_val' while building 'scaling_init_impl'" ] } ], @@ -648,7 +594,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -664,20 +610,9 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 22, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "out1_train = quant_hard_tanh(inp1)\n", "quant_hard_tanh.eval()\n", @@ -711,7 +646,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/notebooks/03_anatomy_of_a_quantizer.ipynb b/notebooks/03_anatomy_of_a_quantizer.ipynb index 2055a1714..3abd4d4ce 100644 --- a/notebooks/03_anatomy_of_a_quantizer.ipynb +++ 
b/notebooks/03_anatomy_of_a_quantizer.ipynb @@ -181,8 +181,9 @@ " Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.\n", " \"\"\"\n", "\n", - " def __init__(self, scaling_impl: Module, quant_delay_steps: int = 0):\n", + " def __init__(self, scaling_impl: Module, signed: bool = True, quant_delay_steps: int = 0):\n", " super(BinaryQuant, self).__init__()\n", + " assert signed, \"Unsigned binary quant not supported\"\n", " self.scaling_impl = scaling_impl\n", " self.bit_width = BitWidthConst(1)\n", " self.zero_point = StatelessBuffer(torch.tensor(0.0))\n", @@ -247,10 +248,10 @@ { "data": { "text/plain": [ - "(tensor([[ 0.1000, 0.1000, 0.1000, 0.1000],\n", - " [-0.1000, -0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, -0.1000, 0.1000, -0.1000]], grad_fn=),\n", + "(tensor([[-0.1000, 0.1000, 0.1000, -0.1000],\n", + " [ 0.1000, 0.1000, -0.1000, 0.1000],\n", + " [ 0.1000, 0.1000, -0.1000, 0.1000],\n", + " [ 0.1000, 0.1000, 0.1000, -0.1000]], grad_fn=),\n", " tensor(0.1000, grad_fn=),\n", " tensor(0.),\n", " tensor(1.))" @@ -292,10 +293,10 @@ { "data": { "text/plain": [ - "(tensor([[-0.1000, 0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, 0.1000, -0.1000, -0.1000],\n", - " [-0.1000, 0.1000, -0.1000, 0.1000],\n", - " [-0.1000, 0.1000, 0.1000, 0.1000]], grad_fn=),\n", + "(tensor([[-0.1000, 0.1000, -0.1000, -0.1000],\n", + " [ 0.1000, 0.1000, 0.1000, 0.1000],\n", + " [-0.1000, -0.1000, 0.1000, 0.1000],\n", + " [-0.1000, -0.1000, -0.1000, -0.1000]], grad_fn=),\n", " tensor(0.1000, grad_fn=),\n", " tensor(0.),\n", " tensor(1.))" @@ -342,10 +343,10 @@ { "data": { "text/plain": [ - "(tensor([[ 1., -1., 1., 1.],\n", - " [ 1., 1., -1., 1.],\n", - " [ 1., 1., 1., -1.],\n", - " [-1., 1., -1., -1.]], grad_fn=),\n", + "(tensor([[-1., 1., -1., 1.],\n", + " [-1., -1., 1., -1.],\n", + " [-1., -1., 1., 1.],\n", + " [ 1., -1., -1., -1.]], grad_fn=),\n", " tensor(1., grad_fn=),\n", " tensor(0.),\n", " tensor(1.))" @@ -380,9 +381,9 @@ "data": { "text/plain": [ "(tensor([[ 0.1000, -0.1000, -0.1000, -0.1000],\n", - " [-0.1000, 0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, 0.1000, 0.1000],\n", - " [-0.1000, 0.1000, -0.1000, 0.1000]], grad_fn=),\n", + " [ 0.1000, -0.1000, -0.1000, -0.1000],\n", + " [ 0.1000, 0.1000, 0.1000, -0.1000],\n", + " [-0.1000, 0.1000, 0.1000, -0.1000]], grad_fn=),\n", " tensor(0.1000, grad_fn=),\n", " tensor(0.),\n", " tensor(1.))" @@ -448,30 +449,30 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.1000, 0.1000, -0.1000],\n", - " [-0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, -0.1000, -0.1000]],\n", - "\n", - " [[-0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, -0.1000, 0.1000],\n", - " [-0.1000, -0.1000, 0.1000]],\n", + "QuantTensor(value=tensor([[[[-0.1000, 0.1000, -0.1000],\n", + " [ 0.1000, 0.1000, 0.1000],\n", + " [-0.1000, 0.1000, -0.1000]],\n", "\n", - " [[ 0.1000, -0.1000, -0.1000],\n", + " [[-0.1000, -0.1000, -0.1000],\n", " [ 0.1000, 0.1000, -0.1000],\n", + " [ 0.1000, -0.1000, 0.1000]],\n", + "\n", + " [[ 0.1000, -0.1000, 0.1000],\n", + " [-0.1000, -0.1000, 0.1000],\n", " [-0.1000, -0.1000, 0.1000]]],\n", "\n", "\n", - " [[[ 0.1000, -0.1000, 0.1000],\n", + " [[[ 0.1000, 0.1000, 0.1000],\n", " [ 0.1000, -0.1000, -0.1000],\n", - " [ 0.1000, -0.1000, 0.1000]],\n", + " [-0.1000, -0.1000, 0.1000]],\n", "\n", " [[-0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, 0.1000, 0.1000],\n", - " [-0.1000, -0.1000, -0.1000]],\n", + " [-0.1000, 0.1000, 0.1000],\n", + " [-0.1000, -0.1000, 0.1000]],\n", "\n", - " [[ 0.1000, 
0.1000, -0.1000],\n", - " [-0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, 0.1000, 0.1000]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=None, training_t=tensor(True))" + " [[-0.1000, 0.1000, -0.1000],\n", + " [-0.1000, -0.1000, -0.1000],\n", + " [ 0.1000, 0.1000, -0.1000]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=None, training_t=tensor(True))" ] }, "execution_count": 11, @@ -518,30 +519,30 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.1000, 0.1000, -0.1000],\n", - " [-0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, 0.1000, -0.1000]],\n", + "QuantTensor(value=tensor([[[[-0.1000, 0.1000, 0.1000],\n", + " [ 0.1000, 0.1000, -0.1000],\n", + " [-0.1000, 0.1000, -0.1000]],\n", "\n", - " [[ 0.1000, 0.1000, 0.1000],\n", + " [[ 0.1000, -0.1000, 0.1000],\n", " [ 0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, -0.1000]],\n", + " [-0.1000, 0.1000, -0.1000]],\n", "\n", - " [[-0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, -0.1000],\n", - " [-0.1000, -0.1000, -0.1000]]],\n", + " [[-0.1000, -0.1000, -0.1000],\n", + " [ 0.1000, 0.1000, -0.1000],\n", + " [ 0.1000, -0.1000, -0.1000]]],\n", "\n", "\n", " [[[ 0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, 0.1000, -0.1000],\n", - " [-0.1000, -0.1000, 0.1000]],\n", + " [ 0.1000, -0.1000, 0.1000],\n", + " [ 0.1000, -0.1000, 0.1000]],\n", "\n", - " [[-0.1000, -0.1000, 0.1000],\n", - " [-0.1000, 0.1000, 0.1000],\n", - " [-0.1000, -0.1000, -0.1000]],\n", + " [[ 0.1000, 0.1000, -0.1000],\n", + " [ 0.1000, -0.1000, -0.1000],\n", + " [-0.1000, 0.1000, -0.1000]],\n", "\n", - " [[-0.1000, 0.1000, -0.1000],\n", - " [-0.1000, 0.1000, -0.1000],\n", - " [-0.1000, 0.1000, -0.1000]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [[-0.1000, -0.1000, -0.1000],\n", + " [ 0.1000, 0.1000, 0.1000],\n", + " [-0.1000, -0.1000, -0.1000]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, "execution_count": 13, @@ -560,7 +561,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -578,39 +579,39 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.1000, -0.1000, 0.1000],\n", + "QuantTensor(value=tensor([[[[ 0.1000, -0.1000, 0.1000],\n", " [-0.1000, -0.1000, -0.1000],\n", - " [ 0.1000, -0.1000, 0.1000]],\n", + " [ 0.1000, 0.1000, 0.1000]],\n", "\n", - " [[-0.1000, 0.1000, -0.1000],\n", - " [-0.1000, 0.1000, 0.1000],\n", + " [[ 0.1000, 0.1000, 0.1000],\n", + " [-0.1000, 0.1000, -0.1000],\n", " [ 0.1000, -0.1000, -0.1000]],\n", "\n", - " [[-0.1000, 0.1000, -0.1000],\n", - " [-0.1000, -0.1000, 0.1000],\n", - " [-0.1000, -0.1000, -0.1000]]],\n", + " [[ 0.1000, -0.1000, 0.1000],\n", + " [ 0.1000, 0.1000, -0.1000],\n", + " [-0.1000, 0.1000, 0.1000]]],\n", "\n", "\n", - " [[[-0.1000, -0.1000, -0.1000],\n", - " [-0.1000, -0.1000, -0.1000],\n", - " [ 0.1000, 0.1000, -0.1000]],\n", - "\n", - " [[-0.1000, -0.1000, 0.1000],\n", - " [-0.1000, 0.1000, -0.1000],\n", + " [[[-0.1000, -0.1000, 0.1000],\n", + " [ 0.1000, 0.1000, -0.1000],\n", " [ 0.1000, -0.1000, 0.1000]],\n", "\n", " [[ 0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, 0.1000]]]], grad_fn=), 
scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [-0.1000, -0.1000, 0.1000],\n", + " [ 0.1000, -0.1000, -0.1000]],\n", + "\n", + " [[ 0.1000, 0.1000, -0.1000],\n", + " [ 0.1000, -0.1000, 0.1000],\n", + " [-0.1000, 0.1000, 0.1000]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 16, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -640,19 +641,27 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/_tensor.py:1362: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/c10/core/TensorImpl.h:1900.)\n", + " return super().rename(names)\n" + ] + }, { "data": { "text/plain": [ - "QuantTensor(value=tensor([[-0.1000, 0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, 0.1000, 0.1000, 0.1000],\n", - " [-0.1000, 0.1000, 0.1000, 0.1000],\n", - " [-0.1000, -0.1000, 0.1000, -0.1000]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + "QuantTensor(value=tensor([[-0.1000, -0.1000, 0.1000, -0.1000],\n", + " [-0.1000, 0.1000, -0.1000, -0.1000],\n", + " [-0.1000, 0.1000, 0.1000, -0.1000],\n", + " [ 0.1000, -0.1000, -0.1000, -0.1000]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 17, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -678,19 +687,19 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[ 0.0010, 0.0010, 0.0010, -0.0010],\n", - " [ 0.0010, -0.0010, 0.0010, -0.0010],\n", - " [-0.0010, -0.0010, -0.0010, -0.0010],\n", - " [ 0.0010, 0.0010, 0.0010, 0.0010]], grad_fn=), scale=tensor(0.0010, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + "QuantTensor(value=tensor([[-0.0010, -0.0010, -0.0010, 0.0010],\n", + " [ 0.0010, 0.0010, -0.0010, 0.0010],\n", + " [-0.0010, -0.0010, 0.0010, -0.0010],\n", + " [-0.0010, -0.0010, -0.0010, -0.0010]], grad_fn=), scale=tensor(0.0010, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 18, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -716,7 +725,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -740,7 +749,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 19, "metadata": { "scrolled": true }, @@ -748,33 +757,33 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.1876, -0.1876, -0.1876],\n", - " [ 0.1876, 0.1876, 0.1876],\n", - " [-0.1876, -0.1876, 0.1876]],\n", + "QuantTensor(value=tensor([[[[-0.1918, 0.1918, 0.1918],\n", + " [ 0.1918, 0.1918, 0.1918],\n", + " [-0.1918, -0.1918, 0.1918]],\n", "\n", - " [[-0.1876, -0.1876, 
0.1876],\n", - " [ 0.1876, 0.1876, -0.1876],\n", - " [-0.1876, 0.1876, 0.1876]],\n", + " [[-0.1918, -0.1918, 0.1918],\n", + " [-0.1918, 0.1918, -0.1918],\n", + " [ 0.1918, 0.1918, 0.1918]],\n", "\n", - " [[-0.1876, -0.1876, -0.1876],\n", - " [ 0.1876, 0.1876, 0.1876],\n", - " [-0.1876, 0.1876, -0.1876]]],\n", + " [[-0.1918, 0.1918, 0.1918],\n", + " [ 0.1918, -0.1918, -0.1918],\n", + " [ 0.1918, 0.1918, 0.1918]]],\n", "\n", "\n", - " [[[-0.1876, -0.1876, -0.1876],\n", - " [ 0.1876, 0.1876, -0.1876],\n", - " [ 0.1876, -0.1876, -0.1876]],\n", + " [[[ 0.1918, -0.1918, 0.1918],\n", + " [-0.1918, -0.1918, 0.1918],\n", + " [ 0.1918, 0.1918, 0.1918]],\n", "\n", - " [[-0.1876, 0.1876, -0.1876],\n", - " [ 0.1876, -0.1876, -0.1876],\n", - " [-0.1876, -0.1876, 0.1876]],\n", + " [[ 0.1918, 0.1918, 0.1918],\n", + " [ 0.1918, -0.1918, -0.1918],\n", + " [ 0.1918, 0.1918, 0.1918]],\n", "\n", - " [[-0.1876, 0.1876, 0.1876],\n", - " [ 0.1876, -0.1876, 0.1876],\n", - " [-0.1876, -0.1876, -0.1876]]]], grad_fn=), scale=tensor(0.1876, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [[-0.1918, 0.1918, -0.1918],\n", + " [ 0.1918, -0.1918, 0.1918],\n", + " [ 0.1918, -0.1918, 0.1918]]]], grad_fn=), scale=tensor(0.1918, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 20, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -793,7 +802,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -802,7 +811,7 @@ "True" ] }, - "execution_count": 21, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -820,16 +829,16 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor(0.1897, grad_fn=)" + "tensor(0.1860, grad_fn=)" ] }, - "execution_count": 22, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -850,7 +859,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 22, "metadata": { "tags": [ "raises-exception" @@ -862,11 +871,11 @@ "evalue": "Error(s) in loading state_dict for QuantConv2d:\n\tMissing key(s) in state_dict: \"weight_quant.tensor_quant.scaling_impl.value\". 
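[Reviewer note, not part of the patch: the hunks above mirror the new `signed` flag of `BinaryQuant` into the notebook. A minimal sketch of the changed surface, assuming the same `ConstScaling(0.1)` construction this notebook already uses; the tuple unpacking matches the `(value, scale, zero_point, bit_width)` outputs printed above.]

```python
import torch
from brevitas.core.quant import BinaryQuant
from brevitas.core.scaling import ConstScaling

# signed=True is the new default; outputs are always +/- scale.
binary_quant = BinaryQuant(scaling_impl=ConstScaling(0.1), signed=True)
value, scale, zero_point, bit_width = binary_quant(torch.randn(4, 4))
assert (value.abs() == scale).all()

# Unsigned binary quantization is now rejected up front by the new assert.
try:
    BinaryQuant(scaling_impl=ConstScaling(0.1), signed=False)
except AssertionError as e:
    print(e)  # Unsigned binary quant not supported
```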
", "output_type": "error", "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mparam_from_max_quant_conv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfloat_conv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[1;32mC:\\ProgramData\\Miniconda3\\envs\\pytorch\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36mload_state_dict\u001b[1;34m(self, state_dict, strict)\u001b[0m\n\u001b[0;32m 1405\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0merror_msgs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m>\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1406\u001b[0m raise RuntimeError('Error(s) in loading state_dict for {}:\\n\\t{}'.format(\n\u001b[1;32m-> 1407\u001b[1;33m self.__class__.__name__, \"\\n\\t\".join(error_msgs)))\n\u001b[0m\u001b[0;32m 1408\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0m_IncompatibleKeys\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmissing_keys\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0munexpected_keys\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1409\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for QuantConv2d:\n\tMissing key(s) in state_dict: \"weight_quant.tensor_quant.scaling_impl.value\". " + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/03_anatomy_of_a_quantizer.ipynb Cell 46\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m param_from_max_quant_conv\u001b[39m.\u001b[39;49mload_state_dict(float_conv\u001b[39m.\u001b[39;49mstate_dict())\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:2152\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[0;34m(self, state_dict, strict, assign)\u001b[0m\n\u001b[1;32m 2147\u001b[0m error_msgs\u001b[39m.\u001b[39minsert(\n\u001b[1;32m 2148\u001b[0m \u001b[39m0\u001b[39m, \u001b[39m'\u001b[39m\u001b[39mMissing key(s) in state_dict: \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m. 
\u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 2149\u001b[0m \u001b[39m'\u001b[39m\u001b[39m, \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mjoin(\u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00mk\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m'\u001b[39m \u001b[39mfor\u001b[39;00m k \u001b[39min\u001b[39;00m missing_keys)))\n\u001b[1;32m 2151\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(error_msgs) \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[0;32m-> 2152\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m'\u001b[39m\u001b[39mError(s) in loading state_dict for \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 2153\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__class__\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mjoin(error_msgs)))\n\u001b[1;32m 2154\u001b[0m \u001b[39mreturn\u001b[39;00m _IncompatibleKeys(missing_keys, unexpected_keys)\n", + "\u001b[0;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for QuantConv2d:\n\tMissing key(s) in state_dict: \"weight_quant.tensor_quant.scaling_impl.value\". " ] } ], @@ -916,39 +925,39 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.1897, -0.1897, 0.1897],\n", - " [-0.1897, 0.1897, -0.1897],\n", - " [-0.1897, 0.1897, -0.1897]],\n", + "QuantTensor(value=tensor([[[[-0.1860, 0.1860, 0.1860],\n", + " [-0.1860, 0.1860, -0.1860],\n", + " [-0.1860, 0.1860, -0.1860]],\n", "\n", - " [[-0.1897, 0.1897, 0.1897],\n", - " [ 0.1897, -0.1897, -0.1897],\n", - " [ 0.1897, -0.1897, 0.1897]],\n", + " [[ 0.1860, -0.1860, 0.1860],\n", + " [-0.1860, 0.1860, 0.1860],\n", + " [ 0.1860, -0.1860, -0.1860]],\n", "\n", - " [[-0.1897, 0.1897, -0.1897],\n", - " [-0.1897, 0.1897, 0.1897],\n", - " [-0.1897, 0.1897, 0.1897]]],\n", + " [[-0.1860, -0.1860, -0.1860],\n", + " [-0.1860, 0.1860, 0.1860],\n", + " [ 0.1860, 0.1860, -0.1860]]],\n", "\n", "\n", - " [[[ 0.1897, 0.1897, 0.1897],\n", - " [-0.1897, 0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, -0.1897]],\n", + " [[[ 0.1860, -0.1860, 0.1860],\n", + " [-0.1860, -0.1860, 0.1860],\n", + " [-0.1860, 0.1860, -0.1860]],\n", "\n", - " [[ 0.1897, -0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, 0.1897]],\n", + " [[ 0.1860, -0.1860, 0.1860],\n", + " [ 0.1860, -0.1860, -0.1860],\n", + " [ 0.1860, 0.1860, 0.1860]],\n", "\n", - " [[-0.1897, 0.1897, -0.1897],\n", - " [-0.1897, 0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, 0.1897]]]], grad_fn=), scale=tensor(0.1897, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [[ 0.1860, -0.1860, -0.1860],\n", + " [-0.1860, -0.1860, -0.1860],\n", + " [-0.1860, 0.1860, 0.1860]]]], grad_fn=), scale=tensor(0.1860, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 25, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -979,7 +988,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 25, "metadata": {}, "outputs": [ { @@ -1013,18 +1022,7 @@ 
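[Reviewer note, not part of the patch: the refreshed traceback above shows expected behaviour, since a parameter-from-statistics quantizer adds a learned scale (`weight_quant.tensor_quant.scaling_impl.value`) that a purely floating-point state dict cannot provide. A hedged sketch of the two usual workarounds, reusing the notebook's own `param_from_max_quant_conv` and `float_conv` variables; `brevitas.config.IGNORE_MISSING_KEYS` is assumed available, as in other Brevitas examples.]

```python
# Option 1: load non-strictly and inspect what was left uninitialised.
result = param_from_max_quant_conv.load_state_dict(float_conv.state_dict(), strict=False)
print(result.missing_keys)  # ['weight_quant.tensor_quant.scaling_impl.value']

# Option 2 (assumed flag): tell Brevitas to ignore quantizer-only keys globally.
from brevitas import config
config.IGNORE_MISSING_KEYS = True
param_from_max_quant_conv.load_state_dict(float_conv.state_dict())
```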
"cell_type": "code", "execution_count": 26, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "quant_conv1 = QuantConv2d(3, 2, (3, 3), weight_quant=MySignedBinaryWeightQuantizer)\n", "quant_conv2 = QuantConv2d(3, 2, (3, 3), weight_quant=quant_conv1.weight_quant)\n", @@ -1036,19 +1034,7 @@ "cell_type": "code", "execution_count": 27, "metadata": {}, - "outputs": [ - { - "ename": "AssertionError", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_58415/1066539094.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mquant_conv1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquant_weight_scale\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mquant_conv2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquant_weight_scale\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mAssertionError\u001b[0m: " - ] - } - ], + "outputs": [], "source": [ "assert (quant_conv1.quant_weight_scale() == quant_conv2.quant_weight_scale()).item()" ] @@ -1065,18 +1051,7 @@ "cell_type": "code", "execution_count": 28, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 28, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "class SharedParamFromMeanWeightQuantizer(MySignedBinaryWeightQuantizer):\n", " \n", @@ -1097,7 +1072,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 29, "metadata": {}, "outputs": [], "source": [ @@ -1140,7 +1115,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ @@ -1159,7 +1134,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 31, "metadata": {}, "outputs": [], "source": [ @@ -1260,42 +1235,42 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.1842, 0.1842, -0.1842],\n", - " [-0.1842, -0.1842, 0.1842],\n", - " [-0.1842, -0.1842, 0.1842]],\n", + "QuantTensor(value=tensor([[[[ 0.1876, -0.1876, 0.1876],\n", + " [ 0.1876, 0.1876, -0.1876],\n", + " [ 0.1876, -0.1876, -0.1876]],\n", "\n", - " [[-0.1842, -0.1842, 0.1842],\n", - " [ 0.1842, -0.1842, 0.1842],\n", - " [ 0.1842, 0.1842, -0.1842]],\n", + " [[-0.1876, -0.1876, -0.1876],\n", + " [ 0.1876, -0.1876, 0.1876],\n", + " [-0.1876, 0.1876, -0.1876]],\n", "\n", - " [[-0.1842, -0.1842, 0.1842],\n", - " [ 0.1842, 0.1842, 0.1842],\n", - " [-0.1842, 0.1842, -0.1842]]],\n", + " [[ 0.1876, 0.1876, -0.1876],\n", + " [-0.1876, -0.1876, -0.1876],\n", + " [-0.1876, -0.1876, 0.1876]]],\n", "\n", "\n", - " [[[ 0.1838, 0.1838, 0.1838],\n", - " [-0.1838, -0.1838, -0.1838],\n", - " [ 0.1838, 0.1838, -0.1838]],\n", + " [[[ 0.1867, 0.1867, -0.1867],\n", + " [-0.1867, 0.1867, 0.1867],\n", + " [-0.1867, -0.1867, 0.1867]],\n", "\n", - " [[ 0.1838, -0.1838, 
0.1838],\n", - " [ 0.1838, 0.1838, 0.1838],\n", - " [-0.1838, 0.1838, -0.1838]],\n", + " [[-0.1867, -0.1867, -0.1867],\n", + " [-0.1867, 0.1867, 0.1867],\n", + " [ 0.1867, 0.1867, -0.1867]],\n", "\n", - " [[-0.1838, 0.1838, -0.1838],\n", - " [ 0.1838, -0.1838, -0.1838],\n", - " [ 0.1838, -0.1838, 0.1838]]]], grad_fn=), scale=tensor([[[[0.1842]]],\n", + " [[-0.1867, -0.1867, 0.1867],\n", + " [ 0.1867, -0.1867, 0.1867],\n", + " [ 0.1867, 0.1867, -0.1867]]]], grad_fn=), scale=tensor([[[[0.1876]]],\n", "\n", "\n", - " [[[0.1838]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [[[0.1867]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 35, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } @@ -1318,42 +1293,42 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.1875, -0.1875, 0.1875],\n", - " [-0.1875, 0.1875, -0.1875],\n", - " [-0.1875, 0.1875, -0.1875]],\n", + "QuantTensor(value=tensor([[[[-0.1859, 0.1859, 0.1859],\n", + " [-0.1859, 0.1859, -0.1859],\n", + " [-0.1859, 0.1859, -0.1859]],\n", "\n", - " [[-0.1875, 0.1875, 0.1875],\n", - " [ 0.1875, -0.1875, -0.1875],\n", - " [ 0.1875, -0.1875, 0.1875]],\n", + " [[ 0.1859, -0.1859, 0.1859],\n", + " [-0.1859, 0.1859, 0.1859],\n", + " [ 0.1859, -0.1859, -0.1859]],\n", "\n", - " [[-0.1875, 0.1875, -0.1875],\n", - " [-0.1875, 0.1875, 0.1875],\n", - " [-0.1875, 0.1875, 0.1875]]],\n", + " [[-0.1859, -0.1859, -0.1859],\n", + " [-0.1859, 0.1859, 0.1859],\n", + " [ 0.1859, 0.1859, -0.1859]]],\n", "\n", "\n", - " [[[ 0.1897, 0.1897, 0.1897],\n", - " [-0.1897, 0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, -0.1897]],\n", + " [[[ 0.1860, -0.1860, 0.1860],\n", + " [-0.1860, -0.1860, 0.1860],\n", + " [-0.1860, 0.1860, -0.1860]],\n", "\n", - " [[ 0.1897, -0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, 0.1897]],\n", + " [[ 0.1860, -0.1860, 0.1860],\n", + " [ 0.1860, -0.1860, -0.1860],\n", + " [ 0.1860, 0.1860, 0.1860]],\n", "\n", - " [[-0.1897, 0.1897, -0.1897],\n", - " [-0.1897, 0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, 0.1897]]]], grad_fn=), scale=tensor([[[[0.1875]]],\n", + " [[ 0.1860, -0.1860, -0.1860],\n", + " [-0.1860, -0.1860, -0.1860],\n", + " [-0.1860, 0.1860, 0.1860]]]], grad_fn=), scale=tensor([[[[0.1859]]],\n", "\n", "\n", - " [[[0.1897]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [[[0.1860]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 36, + "execution_count": 34, "metadata": {}, "output_type": "execute_result" } @@ -1374,19 +1349,19 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[-0.0100, -0.0100, 0.0100, -0.0100],\n", - " [-0.0100, -0.0100, -0.0100, 0.0100],\n", - " [-0.0100, 0.0100, 0.0100, 0.0100],\n", - " [-0.0100, 0.0100, 0.0100, 0.0100]], grad_fn=)" + "tensor([[-0.0100, 0.0100, 0.0100, -0.0100],\n", + " [-0.0100, 0.0100, 0.0100, -0.0100],\n", + " [-0.0100, 0.0100, -0.0100, -0.0100],\n", + " [-0.0100, -0.0100, -0.0100, 0.0100]], grad_fn=)" ] }, - "execution_count": 37, + "execution_count": 35, "metadata": {}, "output_type": "execute_result" } @@ -1421,21 +1396,21 @@ 
"evalue": "'AdvancedActQuantizer' can not resolve attribute 'per_channel_broadcastable_shape'", "output_type": "error", "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mDependencyError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m quant_identity = QuantIdentity(\n\u001b[1;32m----> 4\u001b[1;33m act_quant=AdvancedActQuantizer, is_clamped=True, scaling_per_output_channel=True)\n\u001b[0m", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\quant_activation.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, act_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[0;32m 134\u001b[0m \u001b[0mact_quant\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mact_quant\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 135\u001b[0m \u001b[0mreturn_quant_tensor\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mreturn_quant_tensor\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 136\u001b[1;33m **kwargs)\n\u001b[0m\u001b[0;32m 137\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\quant_layer.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, act_impl, passthrough_act, input_quant, act_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[0;32m 77\u001b[0m \u001b[0mpassthrough_act\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 78\u001b[0m \u001b[0mact_quant\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 79\u001b[1;33m **kwargs)\n\u001b[0m\u001b[0;32m 80\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 81\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\mixin\\act.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, act_impl, passthrough_act, act_quant, **kwargs)\u001b[0m\n\u001b[0;32m 157\u001b[0m \u001b[0mproxy_prefix\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'act_'\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 158\u001b[0m \u001b[0mkwargs_prefix\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m''\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 159\u001b[1;33m **kwargs)\n\u001b[0m\u001b[0;32m 160\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 161\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\mixin\\base.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, quant, proxy_protocol, none_quant_injector, proxy_prefix, kwargs_prefix, **kwargs)\u001b[0m\n\u001b[0;32m 98\u001b[0m \u001b[0mquant_injector\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 99\u001b[0m \u001b[0mquant_injector\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlet\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[0mfilter_kwargs\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mkwargs_prefix\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 100\u001b[1;33m \u001b[0mquant\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mproxy_class\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 101\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 102\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mproxy_protocol\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\runtime_quant.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, quant_layer, quant_injector)\u001b[0m\n\u001b[0;32m 108\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 109\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_layer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 110\u001b[1;33m \u001b[0msuper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mActQuantProxyFromInjector\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_layer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 111\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_passthrough_act\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_is_passthrough_act\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 112\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\quant_proxy.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, quant_layer, quant_injector, export_mode, export_handler)\u001b[0m\n\u001b[0;32m 74\u001b[0m \u001b[1;31m# Use a normal list and not a ModuleList since this is a pointer to parent modules\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 75\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtracked_module_list\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 76\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0madd_tracked_module\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_layer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 77\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexport_handler\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mexport_handler\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 78\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexport_mode\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mexport_mode\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - 
"\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\quant_proxy.py\u001b[0m in \u001b[0;36madd_tracked_module\u001b[1;34m(self, module)\u001b[0m\n\u001b[0;32m 130\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtracked_module_list\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 131\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mupdate_tracked_modules\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 132\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minit_tensor_quant\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 133\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 134\u001b[0m \u001b[1;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Trying to add None as a parent module.\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\runtime_quant.py\u001b[0m in \u001b[0;36minit_tensor_quant\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 120\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 121\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0minit_tensor_quant\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 122\u001b[1;33m \u001b[0mtensor_quant\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtensor_quant\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 123\u001b[0m \u001b[0mact_impl\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mact_impl\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 124\u001b[0m \u001b[0mis_act_enabled\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_is_act_enabled\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mact_impl\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtensor_quant\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - " \u001b[1;31m[... skipping hidden 1 frame]\u001b[0m\n", - "\u001b[1;32mC:\\ProgramData\\Miniconda3\\envs\\pytorch\\lib\\site-packages\\_dependencies\\this.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, __self__)\u001b[0m\n\u001b[0;32m 49\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mkind\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m\".\"\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 50\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 51\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msymbol\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 52\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mDependencyError\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 53\u001b[0m message = (\n", - " \u001b[1;31m[... 
skipping hidden 1 frame]\u001b[0m\n", - "\u001b[1;31mDependencyError\u001b[0m: 'AdvancedActQuantizer' can not resolve attribute 'per_channel_broadcastable_shape'" + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mDependencyError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/03_anatomy_of_a_quantizer.ipynb Cell 76\u001b[0m line \u001b[0;36m3\n\u001b[1;32m 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mbrevitas\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mnn\u001b[39;00m \u001b[39mimport\u001b[39;00m QuantIdentity\n\u001b[0;32m----> 3\u001b[0m quant_identity \u001b[39m=\u001b[39m QuantIdentity(\n\u001b[1;32m 4\u001b[0m act_quant\u001b[39m=\u001b[39;49mAdvancedActQuantizer, is_clamped\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m, scaling_per_output_channel\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_activation.py:113\u001b[0m, in \u001b[0;36mQuantIdentity.__init__\u001b[0;34m(self, act_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[1;32m 108\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\n\u001b[1;32m 109\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 110\u001b[0m act_quant: Optional[ActQuantType] \u001b[39m=\u001b[39m Int8ActPerTensorFloat,\n\u001b[1;32m 111\u001b[0m return_quant_tensor: \u001b[39mbool\u001b[39m \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m,\n\u001b[1;32m 112\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[0;32m--> 113\u001b[0m QuantNLAL\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\n\u001b[1;32m 114\u001b[0m \u001b[39mself\u001b[39;49m,\n\u001b[1;32m 115\u001b[0m input_quant\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m 116\u001b[0m act_impl\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m 117\u001b[0m passthrough_act\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m 118\u001b[0m act_quant\u001b[39m=\u001b[39;49mact_quant,\n\u001b[1;32m 119\u001b[0m return_quant_tensor\u001b[39m=\u001b[39;49mreturn_quant_tensor,\n\u001b[1;32m 120\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py:40\u001b[0m, in \u001b[0;36mQuantNonLinearActLayer.__init__\u001b[0;34m(self, act_impl, passthrough_act, input_quant, act_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[1;32m 38\u001b[0m QuantLayerMixin\u001b[39m.\u001b[39m\u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, return_quant_tensor)\n\u001b[1;32m 39\u001b[0m QuantInputMixin\u001b[39m.\u001b[39m\u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, input_quant, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[0;32m---> 40\u001b[0m QuantNonLinearActMixin\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\u001b[39mself\u001b[39;49m, act_impl, passthrough_act, act_quant, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/mixin/act.py:118\u001b[0m, in \u001b[0;36mQuantNonLinearActMixin.__init__\u001b[0;34m(self, act_impl, passthrough_act, act_quant, act_proxy_prefix, act_kwargs_prefix, **kwargs)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\n\u001b[1;32m 108\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 109\u001b[0m act_impl: Optional[Type[Module]],\n\u001b[0;32m 
(...)\u001b[0m\n\u001b[1;32m 113\u001b[0m act_kwargs_prefix\u001b[39m=\u001b[39m\u001b[39m'\u001b[39m\u001b[39m'\u001b[39m,\n\u001b[1;32m 114\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m prefixed_kwargs \u001b[39m=\u001b[39m {\n\u001b[1;32m 116\u001b[0m act_kwargs_prefix \u001b[39m+\u001b[39m \u001b[39m'\u001b[39m\u001b[39mact_impl\u001b[39m\u001b[39m'\u001b[39m: act_impl,\n\u001b[1;32m 117\u001b[0m act_kwargs_prefix \u001b[39m+\u001b[39m \u001b[39m'\u001b[39m\u001b[39mpassthrough_act\u001b[39m\u001b[39m'\u001b[39m: passthrough_act}\n\u001b[0;32m--> 118\u001b[0m QuantProxyMixin\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\n\u001b[1;32m 119\u001b[0m \u001b[39mself\u001b[39;49m,\n\u001b[1;32m 120\u001b[0m quant\u001b[39m=\u001b[39;49mact_quant,\n\u001b[1;32m 121\u001b[0m proxy_prefix\u001b[39m=\u001b[39;49mact_proxy_prefix,\n\u001b[1;32m 122\u001b[0m kwargs_prefix\u001b[39m=\u001b[39;49mact_kwargs_prefix,\n\u001b[1;32m 123\u001b[0m proxy_protocol\u001b[39m=\u001b[39;49mActQuantProxyProtocol,\n\u001b[1;32m 124\u001b[0m none_quant_injector\u001b[39m=\u001b[39;49mNoneActQuant,\n\u001b[1;32m 125\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mprefixed_kwargs,\n\u001b[1;32m 126\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/mixin/base.py:71\u001b[0m, in \u001b[0;36mQuantProxyMixin.__init__\u001b[0;34m(self, quant, proxy_protocol, none_quant_injector, proxy_prefix, kwargs_prefix, **kwargs)\u001b[0m\n\u001b[1;32m 69\u001b[0m quant_injector \u001b[39m=\u001b[39m quant\n\u001b[1;32m 70\u001b[0m quant_injector \u001b[39m=\u001b[39m quant_injector\u001b[39m.\u001b[39mlet(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39mfilter_kwargs(kwargs_prefix, kwargs))\n\u001b[0;32m---> 71\u001b[0m quant \u001b[39m=\u001b[39m quant_injector\u001b[39m.\u001b[39;49mproxy_class(\u001b[39mself\u001b[39;49m, quant_injector)\n\u001b[1;32m 72\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 73\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(quant, proxy_protocol):\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/runtime_quant.py:89\u001b[0m, in \u001b[0;36mActQuantProxyFromInjector.__init__\u001b[0;34m(self, quant_layer, quant_injector)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, quant_layer, quant_injector):\n\u001b[0;32m---> 89\u001b[0m QuantProxyFromInjector\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\u001b[39mself\u001b[39;49m, quant_layer, quant_injector)\n\u001b[1;32m 90\u001b[0m ActQuantProxyProtocol\u001b[39m.\u001b[39m\u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m)\n\u001b[1;32m 91\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_passthrough_act \u001b[39m=\u001b[39m _is_passthrough_act(quant_injector)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/quant_proxy.py:82\u001b[0m, in \u001b[0;36mQuantProxyFromInjector.__init__\u001b[0;34m(self, quant_layer, quant_injector)\u001b[0m\n\u001b[1;32m 80\u001b[0m \u001b[39m# Use a normal list and not a ModuleList since this is a pointer to parent modules\u001b[39;00m\n\u001b[1;32m 81\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtracked_module_list \u001b[39m=\u001b[39m []\n\u001b[0;32m---> 82\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49madd_tracked_module(quant_layer)\n\u001b[1;32m 83\u001b[0m 
\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdisable_quant \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/quant_proxy.py:120\u001b[0m, in \u001b[0;36mQuantProxyFromInjector.add_tracked_module\u001b[0;34m(self, module)\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtracked_module_list\u001b[39m.\u001b[39mappend(module)\n\u001b[1;32m 119\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mupdate_tracked_modules()\n\u001b[0;32m--> 120\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49minit_tensor_quant()\n\u001b[1;32m 121\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 122\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mTrying to add None as a parent module.\u001b[39m\u001b[39m\"\u001b[39m)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/runtime_quant.py:102\u001b[0m, in \u001b[0;36mActQuantProxyFromInjector.init_tensor_quant\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39minit_tensor_quant\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[0;32m--> 102\u001b[0m tensor_quant \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mquant_injector\u001b[39m.\u001b[39;49mtensor_quant\n\u001b[1;32m 103\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m'\u001b[39m\u001b[39mact_impl\u001b[39m\u001b[39m'\u001b[39m \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mquant_injector:\n\u001b[1;32m 104\u001b[0m act_impl \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mquant_injector\u001b[39m.\u001b[39mact_impl\n", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/_dependencies/this.py:51\u001b[0m, in \u001b[0;36m_ThisSpec.__call__\u001b[0;34m(self, __self__)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[39mif\u001b[39;00m kind \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[1;32m 50\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m---> 51\u001b[0m result \u001b[39m=\u001b[39m \u001b[39mgetattr\u001b[39;49m(result, symbol)\n\u001b[1;32m 52\u001b[0m \u001b[39mexcept\u001b[39;00m DependencyError:\n\u001b[1;32m 53\u001b[0m message \u001b[39m=\u001b[39m (\n\u001b[1;32m 54\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mYou tried to shift this more times than Injector has levels\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 55\u001b[0m )\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/inject/__init__.py:129\u001b[0m, in \u001b[0;36m_ExtendedInjectorType.__getattr__\u001b[0;34m(cls, attrname)\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 127\u001b[0m message \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39m{!r}\u001b[39;00m\u001b[39m can not resolve attribute \u001b[39m\u001b[39m{!r}\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 128\u001b[0m \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, current_attr)\n\u001b[0;32m--> 129\u001b[0m \u001b[39mraise\u001b[39;00m DependencyError(message)\n\u001b[1;32m 131\u001b[0m marker, attribute, args, have_defaults \u001b[39m=\u001b[39m spec\n\u001b[1;32m 133\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mset\u001b[39m(args)\u001b[39m.\u001b[39missubset(cached):\n", + "\u001b[0;31mDependencyError\u001b[0m: 'AdvancedActQuantizer' can not resolve attribute 
'per_channel_broadcastable_shape'" ] } ], @@ -1455,22 +1430,22 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[-0.0100, 0.0100, -0.0100, -0.0100],\n", - " [-0.0100, 0.0100, -0.0100, -0.0100],\n", - " [ 0.0100, -0.0100, 0.0100, -0.0100],\n", - " [ 0.0100, -0.0100, -0.0100, -0.0100]], grad_fn=), scale=tensor([[0.0100],\n", + "QuantTensor(value=tensor([[-0.0100, 0.0100, 0.0100, 0.0100],\n", + " [-0.0100, 0.0100, -0.0100, 0.0100],\n", + " [-0.0100, -0.0100, -0.0100, -0.0100],\n", + " [ 0.0100, 0.0100, 0.0100, -0.0100]], grad_fn=), scale=tensor([[0.0100],\n", " [0.0100],\n", " [0.0100],\n", " [0.0100]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 39, + "execution_count": 37, "metadata": {}, "output_type": "execute_result" } @@ -1506,7 +1481,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/notebooks/Brevitas_TVMCon2021.ipynb b/notebooks/Brevitas_TVMCon2021.ipynb index efd9421f0..2a11e50cc 100644 --- a/notebooks/Brevitas_TVMCon2021.ipynb +++ b/notebooks/Brevitas_TVMCon2021.ipynb @@ -45,8 +45,10 @@ " input_quant: Optional[ActQuantType] = None,\n", " output_quant: Optional[ActQuantType] = None,\n", " return_quant_tensor: bool = False,\n", + " device: Optional[torch.device] = None,\n", + " dtype: Optional[torch.dtype] = None,\n", " **kwargs) -> None:\n", - " Linear.__init__(self, in_features, out_features, bias)\n", + " Linear.__init__(self, in_features, out_features, bias, device=device, dtype=dtype)\n", " QuantWBIOL.__init__(\n", " self,\n", " weight_quant=weight_quant,\n", @@ -118,7 +120,7 @@ " QuantTensor(value=tensor([[-0.0046, 0.3803],\n", " [-0.5820, -0.5224],\n", " [-0.2704, 0.1879],\n", - " [-0.0137, 0.5591]], grad_fn=), scale=tensor(0.0046, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)) \n", + " [-0.0137, 0.5591]], grad_fn=), scale=0.004582525696605444, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=True) \n", "\n" ] } @@ -161,7 +163,7 @@ " tensor([[ -1, 83],\n", " [-127, -114],\n", " [ -59, 41],\n", - " [ -3, 122]], dtype=torch.int32)\n" + " [ -3, 122]], dtype=torch.int8)\n" ] } ], @@ -194,7 +196,15 @@ "Float output:\n", " tensor([[-0.9036, -0.4586, 0.3096, -0.6472],\n", " [ 1.2058, 0.6525, -0.3723, 0.8677],\n", - " [ 1.3873, 0.2801, -0.9009, 0.9507]], grad_fn=)\n" + " [ 1.3873, 0.2801, -0.9009, 0.9507]], grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/_tensor.py:1362: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. 
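[Reviewer note, not part of the patch: the `QuantLinear.__init__` hunk above now threads `device` and `dtype` through to the underlying `nn.Linear` constructor. A minimal sketch of the new constructor surface, under the assumption that the defaulted weight quantizer does not constrain the parameter dtype at construction time.]

```python
import torch
from brevitas.nn import QuantLinear

# device/dtype forward to nn.Linear, so the quantized layer can be built
# directly with the target parameter dtype and placement.
layer = QuantLinear(2, 4, bias=True, dtype=torch.float64)
assert layer.weight.dtype == torch.float64
```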
(Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/c10/core/TensorImpl.h:1900.)\n", + " return super().rename(names)\n" ] } ], @@ -238,7 +248,7 @@ " QuantTensor(value=tensor([[-0.0078, 0.3828],\n", " [-0.5781, -0.5234],\n", " [-0.2734, 0.1875],\n", - " [-0.0156, 0.5625]], grad_fn=), scale=tensor(0.0078, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))\n", + " [-0.0156, 0.5625]], grad_fn=), scale=0.0078125, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=True)\n", "Weight fix point: 7.0\n" ] } @@ -277,7 +287,7 @@ " QuantTensor(value=tensor([[-0.1000, 0.1000],\n", " [-0.1000, -0.1000],\n", " [-0.1000, 0.1000],\n", - " [-0.1000, 0.1000]], grad_fn=), scale=tensor(0.1000), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))\n" + " [-0.1000, 0.1000]], grad_fn=), scale=0.10000000149011612, zero_point=0.0, bit_width=1.0, signed_t=True, training_t=True)\n" ] } ], @@ -372,7 +382,7 @@ "Quant output:\n", " tensor([[-0.9109, -0.4609, 0.3135, -0.6523],\n", " [ 1.2089, 0.6524, -0.3752, 0.8697],\n", - " [ 1.3893, 0.2816, -0.9011, 0.9521]], grad_fn=)\n" + " [ 1.3893, 0.2816, -0.9011, 0.9521]], grad_fn=)\n" ] } ], @@ -409,7 +419,7 @@ "Quant output:\n", " QuantTensor(value=tensor([[-0.9109, -0.4609, 0.3135, -0.6523],\n", " [ 1.2089, 0.6524, -0.3752, 0.8697],\n", - " [ 1.3893, 0.2816, -0.9011, 0.9521]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(17.), signed_t=tensor(True), training_t=tensor(True))\n" + " [ 1.3893, 0.2816, -0.9011, 0.9521]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=0.0, bit_width=17.0, signed_t=True, training_t=True)\n" ] } ], @@ -452,12 +462,12 @@ "Quant input:\n", " QuantTensor(value=tensor([[ 1.5490, -0.2894],\n", " [-2.1788, 0.5617],\n", - " [-1.0894, -1.3958]], grad_fn=), scale=tensor(0.0170, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)) \n", + " [-1.0894, -1.3958]], grad_fn=), scale=0.017021792009472847, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=True) \n", "\n", "Quant output:\n", " QuantTensor(value=tensor([[-0.9109, -0.4609, 0.3135, -0.6523],\n", " [ 1.2089, 0.6524, -0.3752, 0.8697],\n", - " [ 1.3893, 0.2816, -0.9011, 0.9521]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(17.), signed_t=tensor(True), training_t=tensor(True))\n" + " [ 1.3893, 0.2816, -0.9011, 0.9521]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=0.0, bit_width=17.0, signed_t=True, training_t=True)\n" ] } ], @@ -511,7 +521,7 @@ "Quant output:\n", " QuantTensor(value=tensor([[1.5410, 0.0000],\n", " [0.0000, 0.5681],\n", - " [0.0000, 0.0000]], grad_fn=), scale=tensor(0.0060, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))\n" + " [0.0000, 0.0000]], grad_fn=), scale=0.006043121684342623, zero_point=0.0, bit_width=8.0, signed_t=False, training_t=True)\n" ] } ], @@ -555,11 +565,11 @@ "Quant output after QuantIdentity:\n", " QuantTensor(value=tensor([[ 1.5490, -0.2894],\n", " [-2.1788, 0.5617],\n", - " [-1.0894, -1.3958]], grad_fn=), scale=tensor(0.0170, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))\n", + " [-1.0894, -1.3958]], grad_fn=), scale=0.017021792009472847, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=True)\n", "Quant output after QuantReLU:\n", " 
QuantTensor(value=tensor([[1.5490, 0.0000],\n", " [0.0000, 0.5588],\n", - " [0.0000, 0.0000]], grad_fn=), scale=tensor(0.0061, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))\n" + " [0.0000, 0.0000]], grad_fn=), scale=0.006074443459510803, zero_point=0.0, bit_width=8.0, signed_t=False, training_t=True)\n" ] } ], @@ -602,7 +612,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "metadata": { "tags": [ "raises-exception" @@ -611,18 +621,17 @@ "outputs": [ { "ename": "RuntimeError", - "evalue": "Input scale required", + "evalue": "QuantLayer is not correctly configured", "output_type": "error", "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32mC:\\Users\\ALESSA~1\\AppData\\Local\\Temp/ipykernel_18920/2660651517.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 6\u001b[0m \u001b[0mquant_linear\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mQuantLinear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m4\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbias_quant\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mInt16Bias\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mreturn_quant_tensor\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 7\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 8\u001b[1;33m \u001b[0mquant_output\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant_linear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfloat_input\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[1;32m~\\miniconda3\\envs\\pt190\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1049\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1050\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1051\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1052\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1053\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\users\\alessandro\\documenti\\brevitas_tvmcon\\src\\brevitas\\nn\\quant_linear.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 96\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 97\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m 
\u001b[1;33m->\u001b[0m \u001b[0mUnion\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 98\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 99\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 100\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0minner_forward_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_weight\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_bias\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\users\\alessandro\\documenti\\brevitas_tvmcon\\src\\brevitas\\nn\\quant_layer.py\u001b[0m in \u001b[0;36mforward_impl\u001b[1;34m(self, inp)\u001b[0m\n\u001b[0;32m 355\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 356\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 357\u001b[1;33m \u001b[0mquant_bias\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias_quant\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moutput_scale\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moutput_bit_width\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 358\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtraining\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcache_inference_quant_bias\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 359\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_cached_bias\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_CachedIO\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_bias\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmetadata_only\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32m~\\miniconda3\\envs\\pt190\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1049\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1050\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1051\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1052\u001b[0m \u001b[1;31m# Do not call functions when jit is 
used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1053\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\users\\alessandro\\documenti\\brevitas_tvmcon\\src\\brevitas\\proxy\\parameter_quant.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x, input_scale, input_bit_width)\u001b[0m\n\u001b[0;32m 194\u001b[0m \u001b[0mimpl\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexport_handler\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexport_mode\u001b[0m \u001b[1;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtensor_quant\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 195\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrequires_input_scale\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0minput_scale\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 196\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Input scale required\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 197\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrequires_input_bit_width\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0minput_bit_width\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 198\u001b[0m \u001b[1;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Input bit-width required\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;31mRuntimeError\u001b[0m: Input scale required" + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/Brevitas_TVMCon2021.ipynb Cell 35\u001b[0m line \u001b[0;36m8\n\u001b[1;32m 5\u001b[0m float_input \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mrandn(\u001b[39m3\u001b[39m, \u001b[39m2\u001b[39m)\n\u001b[1;32m 6\u001b[0m quant_linear \u001b[39m=\u001b[39m QuantLinear(\u001b[39m2\u001b[39m, \u001b[39m4\u001b[39m, bias\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m, bias_quant\u001b[39m=\u001b[39mInt16Bias, return_quant_tensor\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[0;32m----> 8\u001b[0m quant_output \u001b[39m=\u001b[39m quant_linear(float_input)\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compiled_call_impl(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs) \u001b[39m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_impl(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", 
+ "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1529\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_linear.py:66\u001b[0m, in \u001b[0;36mQuantLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m: Union[Tensor, QuantTensor]) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Union[Tensor, QuantTensor]:\n\u001b[0;32m---> 66\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mforward_impl(\u001b[39minput\u001b[39;49m)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py:328\u001b[0m, in \u001b[0;36mQuantWeightBiasInputOutputLayer.forward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 324\u001b[0m compute_output_quant_tensor \u001b[39m=\u001b[39m \u001b[39misinstance\u001b[39m(quant_input, QuantTensor) \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(\n\u001b[1;32m 325\u001b[0m quant_weight, QuantTensor)\n\u001b[1;32m 326\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (compute_output_quant_tensor \u001b[39mor\u001b[39;00m\n\u001b[1;32m 327\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_output_quant_enabled) \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreturn_quant_tensor:\n\u001b[0;32m--> 328\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mQuantLayer is not correctly configured\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 330\u001b[0m \u001b[39mif\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreturn_quant_tensor \u001b[39mor\u001b[39;00m\n\u001b[1;32m 331\u001b[0m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_bias_quant_enabled \u001b[39mand\u001b[39;00m\n\u001b[1;32m 332\u001b[0m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias_quant\u001b[39m.\u001b[39mrequires_input_scale \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias_quant\u001b[39m.\u001b[39mrequires_input_bit_width))):\n\u001b[1;32m 333\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(quant_input, QuantTensor) \u001b[39mand\u001b[39;00m 
\u001b[39misinstance\u001b[39m(quant_weight, QuantTensor):\n", + "\u001b[0;31mRuntimeError\u001b[0m: QuantLayer is not correctly configured" ] } ], @@ -646,7 +655,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -654,10 +663,10 @@ "text/plain": [ "QuantTensor(value=tensor([[-0.6541, 0.1263, 0.1680, -0.1231],\n", " [ 1.4658, 1.2395, -0.5207, 1.3989],\n", - " [ 1.6461, 0.8687, -1.0466, 1.4813]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(18.), signed_t=tensor(True), training_t=tensor(True))" + " [ 1.6461, 0.8687, -1.0466, 1.4813]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(18.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 14, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -703,7 +712,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -713,15 +722,15 @@ "Eval mode add quant inputs:\n", " QuantTensor(value=tensor([[ 1.5335, -0.2875],\n", " [-2.0447, 0.5751],\n", - " [-1.0863, -1.4057]]), scale=tensor(0.0160), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(False)) \n", + " [-1.0863, -1.4057]]), scale=0.015974320471286774, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=False) \n", " QuantTensor(value=tensor([[ 0.3994, 0.8307],\n", " [-0.7188, -0.3994],\n", - " [-0.5910, 0.1757]]), scale=tensor(0.0160), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(False)) \n", + " [-0.5910, 0.1757]]), scale=0.015974320471286774, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=False) \n", "\n", "Eval mode add quant output:\n", " QuantTensor(value=tensor([[ 1.9329, 0.5431],\n", " [-2.7636, 0.1757],\n", - " [-1.6773, -1.2300]]), scale=tensor(0.0160), zero_point=tensor(0.), bit_width=tensor(9.), signed_t=tensor(True), training_t=tensor(False))\n" + " [-1.6773, -1.2300]]), scale=0.015974320471286774, zero_point=0.0, bit_width=9.0, signed_t=True, training_t=False)\n" ] } ], @@ -769,7 +778,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -784,7 +793,7 @@ " [ 0.1628, 0.8685, -0.1448, -0.1086]],\n", "\n", " [[ 0.9228, 1.2666, 2.0084, 0.0543],\n", - " [ 0.6152, -0.4162, -0.8323, -2.3160]]], grad_fn=), scale=tensor(0.0181, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)) \n", + " [ 0.6152, -0.4162, -0.8323, -2.3160]]], grad_fn=), scale=0.018094077706336975, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=True) \n", "\n", "Quant output:\n", " QuantTensor(value=tensor([[[-1.1218, -0.2533],\n", @@ -794,15 +803,7 @@ " [ 0.8685, -0.1086]],\n", "\n", " [[ 1.2666, 2.0084],\n", - " [ 0.6152, -0.8323]]], grad_fn=), scale=tensor(0.0181, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "C:\\Users\\Alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages\\torch\\nn\\functional.py:652: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. 
(Triggered internally at ..\\c10/core/TensorImpl.h:1156.)\n", - " return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)\n" + " [ 0.6152, -0.8323]]], grad_fn=), scale=0.018094077706336975, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=True)\n" ] } ], @@ -830,7 +831,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -845,7 +846,7 @@ " [ 0.1628, 0.8685, -0.1448, -0.1086]],\n", "\n", " [[ 0.9228, 1.2666, 2.0084, 0.0543],\n", - " [ 0.6152, -0.4162, -0.8323, -2.3160]]], grad_fn=), scale=tensor(0.0181, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)) \n", + " [ 0.6152, -0.4162, -0.8323, -2.3160]]], grad_fn=), scale=0.018094077706336975, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=True) \n", "\n", "Quant output:\n", " tensor([[[-0.8082, -0.8204, -0.2480, -0.4089],\n", @@ -855,7 +856,15 @@ " [ 0.1614, 0.7006, -0.1438, -0.1081]],\n", "\n", " [[ 0.7272, 0.8529, 0.9646, 0.0542],\n", - " [ 0.5478, -0.3937, -0.6817, -0.9807]]], grad_fn=)\n" + " [ 0.5478, -0.3937, -0.6817, -0.9807]]], grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_1718099/661358273.py:7: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n", + " quant_output = torch.tanh(quant_input)\n" ] } ], @@ -883,7 +892,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -893,14 +902,24 @@ "Eval mode concat quant inputs:\n", " QuantTensor(value=tensor([[ 1.5335, -0.2875],\n", " [-2.0447, 0.5751],\n", - " [-1.0863, -1.4057]]), scale=tensor(0.0160), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(False)) QuantTensor(value=tensor([[ 0.3994, 0.8307],\n", + " [-1.0863, -1.4057]]), scale=0.015974320471286774, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=False) QuantTensor(value=tensor([[ 0.3994, 0.8307],\n", " [-0.7188, -0.3994],\n", - " [-0.5910, 0.1757]]), scale=tensor(0.0160), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(False)) \n", + " [-0.5910, 0.1757]]), scale=0.015974320471286774, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=False) \n", "\n", "Eval mode concat quant output:\n", " QuantTensor(value=tensor([[ 1.5335, -0.2875, 0.3994, 0.8307],\n", " [-2.0447, 0.5751, -0.7188, -0.3994],\n", - " [-1.0863, -1.4057, -0.5910, 0.1757]]), scale=tensor(0.0160), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(False))\n" + " [-1.0863, -1.4057, -0.5910, 0.1757]]), scale=0.015974320471286774, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=False)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_1718099/3932472163.py:8: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. 
(Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n", + " train_mode_cat = torch.cat([quant_identity(float_inp1), quant_identity(float_inp2)], dim=1)\n", + "/tmp/ipykernel_1718099/3932472163.py:14: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n", + " eval_mode_cat = torch.cat([eval_quant_inp1, eval_quant_inp2], dim=1)\n" ] } ], @@ -946,7 +965,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -957,7 +976,7 @@ " QuantTensor(value=tensor([[-0.0000, 0.3880],\n", " [-0.5820, -0.5044],\n", " [-0.2716, 0.1940],\n", - " [-0.0000, 0.5432]], grad_fn=), scale=tensor(0.0388, grad_fn=), zero_point=tensor(0.), bit_width=tensor(5.), signed_t=tensor(True), training_t=tensor(True))\n" + " [-0.0000, 0.5432]], grad_fn=), scale=0.03879871591925621, zero_point=0.0, bit_width=5.0, signed_t=True, training_t=True)\n" ] } ], @@ -980,7 +999,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -994,7 +1013,7 @@ " [-0.0000, 0.5607]], grad_fn=), scale=tensor([[0.0253],\n", " [0.0388],\n", " [0.0182],\n", - " [0.0374]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(5.), signed_t=tensor(True), training_t=tensor(True))\n" + " [0.0374]], grad_fn=), zero_point=0.0, bit_width=5.0, signed_t=True, training_t=True)\n" ] } ], @@ -1017,7 +1036,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -1027,7 +1046,7 @@ "QuantTensor:\n", " QuantTensor(value=tensor([[ 1.6341, -0.5447],\n", " [-2.1788, 0.5447],\n", - " [-1.0894, -1.6341]], grad_fn=), scale=tensor(0.5447, grad_fn=), zero_point=tensor(0.), bit_width=tensor(3.), signed_t=tensor(True), training_t=tensor(True))\n" + " [-1.0894, -1.6341]], grad_fn=), scale=0.5446973443031311, zero_point=0.0, bit_width=3.0, signed_t=True, training_t=True)\n" ] } ], @@ -1050,7 +1069,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -1061,7 +1080,7 @@ " [0.0000, 0.0000]], grad_fn=), scale=tensor(0.0235, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))" ] }, - "execution_count": 22, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -1087,7 +1106,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 24, "metadata": {}, "outputs": [ { @@ -1100,8 +1119,8 @@ "\n", "Per-channel quant output:\n", " QuantTensor(value=tensor([[[ 0.8616, -0.7012, 0.4503],\n", - " [-1.1285, -0.4937, -0.1901]]], grad_fn=), scale=tensor([[[0.0021],\n", - " [0.0013]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(17.), signed_t=tensor(True), training_t=tensor(True))\n" + " [-1.1285, -0.4937, -0.1901]]], grad_fn=), scale=tensor([[[0.0021],\n", + " [0.0013]]], grad_fn=), zero_point=0.0, bit_width=17.0, signed_t=True, training_t=True)\n" ] } ], @@ -1145,7 +1164,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 25, "metadata": {}, "outputs": [ { @@ -1158,8 +1177,8 @@ "\n", "Per-channel quant output:\n", " QuantTensor(value=tensor([[[ 0.8616, -0.7012, 0.4503],\n", - " [-1.1285, -0.4937, -0.1901]]], grad_fn=), scale=tensor([[[0.0021],\n", - " [0.0013]]], 
grad_fn=), zero_point=tensor(0.), bit_width=tensor(17.), signed_t=tensor(True), training_t=tensor(True))\n" + " [-1.1285, -0.4937, -0.1901]]], grad_fn=), scale=tensor([[[0.0021],\n", + " [0.0013]]], grad_fn=), zero_point=0.0, bit_width=17.0, signed_t=True, training_t=True)\n" ] } ], @@ -1219,7 +1238,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -1253,7 +1272,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -1293,7 +1312,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 28, "metadata": {}, "outputs": [ { @@ -1307,7 +1326,7 @@ " [-0.0132, 0.5607]], grad_fn=), scale=tensor([[0.0030],\n", " [0.0046],\n", " [0.0021],\n", - " [0.0044]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(8., grad_fn=), signed_t=tensor(True), training_t=tensor(True))\n" + " [0.0044]], grad_fn=), zero_point=0.0, bit_width=8.0, signed_t=True, training_t=True)\n" ] } ], @@ -1337,7 +1356,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -1345,11 +1364,11 @@ "text/plain": [ "QuantTensor(value=tensor([[-0.9109, -0.4588, 0.3119, -0.6530],\n", " [ 1.2089, 0.6493, -0.3731, 0.8706],\n", - " [ 1.3893, 0.2823, -0.8979, 0.9543]], grad_fn=), scale=tensor([[9.0542e-05, 3.9068e-05, 5.6866e-05, 6.4251e-05]],\n", - " grad_fn=), zero_point=tensor(0.), bit_width=tensor(17., grad_fn=), signed_t=tensor(True), training_t=tensor(True))" + " [ 1.3893, 0.2823, -0.8979, 0.9543]], grad_fn=), scale=tensor([[9.0542e-05, 3.9068e-05, 5.6866e-05, 6.4251e-05]],\n", + " grad_fn=), zero_point=tensor(0.), bit_width=tensor(17., grad_fn=), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 28, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -1394,7 +1413,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 30, "metadata": { "tags": [ "raises-exception" @@ -1406,11 +1425,11 @@ "evalue": "Error(s) in loading state_dict for QuantLinear:\n\tMissing key(s) in state_dict: \"input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value\", \"input_quant.fused_activation_quant_proxy.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\", \"weight_quant.tensor_quant.scaling_impl.value\", \"weight_quant.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\". 
", "output_type": "error", "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32mC:\\Users\\ALESSA~1\\AppData\\Local\\Temp/ipykernel_18920/1653109852.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 10\u001b[0m return_quant_tensor=True, bias=False)\n\u001b[0;32m 11\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 12\u001b[1;33m \u001b[0mquant_linear\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfloat_linear\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[1;32m~\\miniconda3\\envs\\pt190\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36mload_state_dict\u001b[1;34m(self, state_dict, strict)\u001b[0m\n\u001b[0;32m 1405\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0merror_msgs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m>\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1406\u001b[0m raise RuntimeError('Error(s) in loading state_dict for {}:\\n\\t{}'.format(\n\u001b[1;32m-> 1407\u001b[1;33m self.__class__.__name__, \"\\n\\t\".join(error_msgs)))\n\u001b[0m\u001b[0;32m 1408\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0m_IncompatibleKeys\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmissing_keys\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0munexpected_keys\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1409\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for QuantLinear:\n\tMissing key(s) in state_dict: \"input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value\", \"input_quant.fused_activation_quant_proxy.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\", \"weight_quant.tensor_quant.scaling_impl.value\", \"weight_quant.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\". 
" + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/Brevitas_TVMCon2021.ipynb Cell 75\u001b[0m line \u001b[0;36m1\n\u001b[1;32m 5\u001b[0m float_linear \u001b[39m=\u001b[39m nn\u001b[39m.\u001b[39mLinear(\u001b[39m2\u001b[39m, \u001b[39m4\u001b[39m, bias\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m)\n\u001b[1;32m 6\u001b[0m quant_linear \u001b[39m=\u001b[39m QuantLinear(\n\u001b[1;32m 7\u001b[0m \u001b[39m2\u001b[39m, \u001b[39m4\u001b[39m, \n\u001b[1;32m 8\u001b[0m input_quant\u001b[39m=\u001b[39mLearnedIntActPerTensorFloat,\n\u001b[1;32m 9\u001b[0m weight_quant\u001b[39m=\u001b[39mLearnedIntWeightPerChannelFloat, \n\u001b[1;32m 10\u001b[0m return_quant_tensor\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m, bias\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m)\n\u001b[0;32m---> 12\u001b[0m quant_linear\u001b[39m.\u001b[39;49mload_state_dict(float_linear\u001b[39m.\u001b[39;49mstate_dict())\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:2152\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[0;34m(self, state_dict, strict, assign)\u001b[0m\n\u001b[1;32m 2147\u001b[0m error_msgs\u001b[39m.\u001b[39minsert(\n\u001b[1;32m 2148\u001b[0m \u001b[39m0\u001b[39m, \u001b[39m'\u001b[39m\u001b[39mMissing key(s) in state_dict: \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m. \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 2149\u001b[0m \u001b[39m'\u001b[39m\u001b[39m, \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mjoin(\u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00mk\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m'\u001b[39m \u001b[39mfor\u001b[39;00m k \u001b[39min\u001b[39;00m missing_keys)))\n\u001b[1;32m 2151\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(error_msgs) \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[0;32m-> 2152\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m'\u001b[39m\u001b[39mError(s) in loading state_dict for \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 2153\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__class__\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mjoin(error_msgs)))\n\u001b[1;32m 2154\u001b[0m \u001b[39mreturn\u001b[39;00m _IncompatibleKeys(missing_keys, unexpected_keys)\n", + "\u001b[0;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for QuantLinear:\n\tMissing key(s) in state_dict: \"input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value\", \"input_quant.fused_activation_quant_proxy.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\", \"weight_quant.tensor_quant.scaling_impl.value\", \"weight_quant.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\". 
" ] } ], @@ -1440,7 +1459,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 31, "metadata": {}, "outputs": [ { @@ -1449,7 +1468,7 @@ "" ] }, - "execution_count": 30, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" } @@ -1481,7 +1500,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 32, "metadata": {}, "outputs": [ { @@ -1548,7 +1567,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 33, "metadata": {}, "outputs": [ { @@ -1575,10 +1594,12 @@ " (stats): _Stats(\n", " (stats_impl): AbsPercentile()\n", " )\n", - " (restrict_clamp_scaling): _RestrictClampValue(\n", - " (clamp_min_ste): Identity()\n", + " (restrict_scaling): _RestrictValue(\n", " (restrict_value_impl): FloatRestrictValue()\n", " )\n", + " (clamp_scaling): _ClampValue(\n", + " (clamp_min_ste): ScalarClampMinSte()\n", + " )\n", " (restrict_inplace_preprocess): Identity()\n", " (restrict_preprocess): Identity()\n", " )\n", @@ -1595,7 +1616,7 @@ ")" ] }, - "execution_count": 32, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } @@ -1617,7 +1638,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 34, "metadata": {}, "outputs": [], "source": [ @@ -1669,7 +1690,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 35, "metadata": {}, "outputs": [], "source": [ @@ -1729,7 +1750,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 36, "metadata": {}, "outputs": [], "source": [ @@ -1785,7 +1806,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 37, "metadata": {}, "outputs": [ { @@ -1794,7 +1815,7 @@ "True" ] }, - "execution_count": 36, + "execution_count": 37, "metadata": {}, "output_type": "execute_result" } @@ -1845,20 +1866,21 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Requirement already satisfied: netron in c:\\users\\alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages (5.3.9)\n", - "Requirement already satisfied: onnx in c:\\users\\alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages (1.10.2)\n", - "Requirement already satisfied: onnxoptimizer in c:\\users\\alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages (0.2.6)\n", - "Requirement already satisfied: numpy>=1.16.6 in c:\\users\\alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages (from onnx) (1.21.2)\n", - "Requirement already satisfied: typing-extensions>=3.6.2.1 in c:\\users\\alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages (from onnx) (3.10.0.2)\n", - "Requirement already satisfied: protobuf in c:\\users\\alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages (from onnx) (3.19.1)\n", - "Requirement already satisfied: six in c:\\users\\alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages (from onnx) (1.16.0)\n" + "\u001b[33mWARNING: Ignoring invalid distribution ~nnx-weekly (/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages)\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: netron in /home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages (7.2.9)\n", + "Requirement already satisfied: onnx in /home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages (1.15.0)\n", + "Requirement already satisfied: onnxoptimizer in /home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages (0.3.13)\n", + "Requirement already satisfied: numpy in 
/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages (from onnx) (1.26.0)\n", + "Requirement already satisfied: protobuf>=3.20.2 in /home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages (from onnx) (3.20.3)\n", + "\u001b[33mWARNING: Ignoring invalid distribution ~nnx-weekly (/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m" ] } ], @@ -1868,7 +1890,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 39, "metadata": {}, "outputs": [], "source": [ @@ -1894,9 +1916,202 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 40, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "/home/giuseppe/Documents/git/brevitas/src/brevitas/export/onnx/standard/manager.py:26: UserWarning: ONNX opset version set to 13, override with opset_version=\n", + " warnings.warn(f\"ONNX opset version set to {DEFAULT_OPSET}, override with {ka}=\")\n" + ] + }, + { + "data": { + "text/plain": [ + "ir_version: 7\n", + "producer_name: \"pytorch\"\n", + "producer_version: \"2.1.0\"\n", + "graph {\n", + " node {\n", + " output: \"/export_handler/Constant_output_0\"\n", + " name: \"/export_handler/Constant\"\n", + " op_type: \"Constant\"\n", + " attribute {\n", + " name: \"value\"\n", + " t {\n", + " data_type: 1\n", + " raw_data: \"\\000\\000\\000<\"\n", + " }\n", + " type: TENSOR\n", + " }\n", + " }\n", + " node {\n", + " output: \"/export_handler/Constant_1_output_0\"\n", + " name: \"/export_handler/Constant_1\"\n", + " op_type: \"Constant\"\n", + " attribute {\n", + " name: \"value\"\n", + " t {\n", + " data_type: 3\n", + " raw_data: \"\\000\"\n", + " }\n", + " type: TENSOR\n", + " }\n", + " }\n", + " node {\n", + " input: \"inp.1\"\n", + " input: \"/export_handler/Constant_output_0\"\n", + " input: \"/export_handler/Constant_1_output_0\"\n", + " output: \"/export_handler/QuantizeLinear_output_0\"\n", + " name: \"/export_handler/QuantizeLinear\"\n", + " op_type: \"QuantizeLinear\"\n", + " }\n", + " node {\n", + " output: \"/export_handler/Constant_2_output_0\"\n", + " name: \"/export_handler/Constant_2\"\n", + " op_type: \"Constant\"\n", + " attribute {\n", + " name: \"value\"\n", + " t {\n", + " dims: 4\n", + " dims: 2\n", + " dims: 3\n", + " data_type: 3\n", + " raw_data: \"\\003\\006\\376\\006\\377\\001\\007\\371\\373\\376\\375\\006\\373\\375\\373\\371\\374\\006\\003\\004\\000\\374\\001\\371\"\n", + " }\n", + " type: TENSOR\n", + " }\n", + " }\n", + " node {\n", + " output: \"/export_handler/Constant_3_output_0\"\n", + " name: \"/export_handler/Constant_3\"\n", + " op_type: \"Constant\"\n", + " attribute {\n", + " name: \"value\"\n", + " t {\n", + " data_type: 1\n", + " raw_data: \"\\242\\272_=\"\n", + " }\n", + " type: TENSOR\n", + " }\n", + " }\n", + " node {\n", + " output: \"/export_handler/Constant_4_output_0\"\n", + " name: \"/export_handler/Constant_4\"\n", + " op_type: \"Constant\"\n", + " attribute {\n", + " name: \"value\"\n", + " t {\n", + " dims: 4\n", + " data_type: 6\n", + " raw_data: \"M\\375\\377\\377\\023\\376\\377\\377\\\\\\002\\000\\0001\\002\\000\\000\"\n", + " }\n", + " type: TENSOR\n", + " }\n", + " }\n", + " node 
{\n", + " input: \"/export_handler/QuantizeLinear_output_0\"\n", + " input: \"/export_handler/Constant_output_0\"\n", + " input: \"/export_handler/Constant_1_output_0\"\n", + " input: \"/export_handler/Constant_2_output_0\"\n", + " input: \"/export_handler/Constant_3_output_0\"\n", + " input: \"/export_handler/Constant_1_output_0\"\n", + " input: \"/export_handler/Constant_output_0\"\n", + " input: \"/export_handler/Constant_1_output_0\"\n", + " input: \"/export_handler/Constant_4_output_0\"\n", + " output: \"/export_handler/QLinearConv_output_0\"\n", + " name: \"/export_handler/QLinearConv\"\n", + " op_type: \"QLinearConv\"\n", + " attribute {\n", + " name: \"dilations\"\n", + " ints: 1\n", + " type: INTS\n", + " }\n", + " attribute {\n", + " name: \"group\"\n", + " i: 1\n", + " type: INT\n", + " }\n", + " attribute {\n", + " name: \"kernel_shape\"\n", + " ints: 3\n", + " type: INTS\n", + " }\n", + " attribute {\n", + " name: \"pads\"\n", + " ints: 0\n", + " ints: 0\n", + " type: INTS\n", + " }\n", + " attribute {\n", + " name: \"strides\"\n", + " ints: 1\n", + " type: INTS\n", + " }\n", + " }\n", + " node {\n", + " input: \"/export_handler/QLinearConv_output_0\"\n", + " input: \"/export_handler/Constant_output_0\"\n", + " input: \"/export_handler/Constant_1_output_0\"\n", + " output: \"10\"\n", + " name: \"/export_handler/DequantizeLinear\"\n", + " op_type: \"DequantizeLinear\"\n", + " }\n", + " name: \"main_graph\"\n", + " input {\n", + " name: \"inp.1\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 1\n", + " }\n", + " dim {\n", + " dim_value: 2\n", + " }\n", + " dim {\n", + " dim_value: 5\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " output {\n", + " name: \"10\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 1\n", + " }\n", + " dim {\n", + " dim_value: 4\n", + " }\n", + " dim {\n", + " dim_value: 3\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + "}\n", + "opset_import {\n", + " domain: \"\"\n", + " version: 13\n", + "}" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "torch.manual_seed(0)\n", "\n", @@ -1918,7 +2133,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": null, "metadata": { "tags": [ "skip-execution" @@ -1980,9 +2195,319 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 41, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "ir_version: 8\n", + "producer_name: \"pytorch\"\n", + "producer_version: \"2.1.0\"\n", + "graph {\n", + " node {\n", + " input: \"x.87\"\n", + " input: \"/input_quant/export_handler/Constant_1_output_0\"\n", + " input: \"/input_quant/export_handler/Constant_2_output_0\"\n", + " input: \"/input_quant/export_handler/Constant_output_0\"\n", + " output: \"/input_quant/export_handler/Quant_output_0\"\n", + " name: \"/input_quant/export_handler/Quant\"\n", + " op_type: \"Quant\"\n", + " attribute {\n", + " name: \"narrow\"\n", + " i: 0\n", + " type: INT\n", + " }\n", + " attribute {\n", + " name: \"rounding_mode\"\n", + " s: \"ROUND\"\n", + " type: STRING\n", + " }\n", + " attribute {\n", + " name: \"signed\"\n", + " i: 1\n", + " type: INT\n", + " }\n", + " domain: \"onnx.brevitas\"\n", + " }\n", + " node {\n", + " input: \"/weight_quant/export_handler/Constant_1_output_0\"\n", + " input: \"/weight_quant/export_handler/Constant_2_output_0\"\n", + " input: 
\"/input_quant/export_handler/Constant_2_output_0\"\n", + " input: \"/weight_quant/export_handler/Constant_output_0\"\n", + " output: \"/weight_quant/export_handler/Quant_output_0\"\n", + " name: \"/weight_quant/export_handler/Quant\"\n", + " op_type: \"Quant\"\n", + " attribute {\n", + " name: \"narrow\"\n", + " i: 1\n", + " type: INT\n", + " }\n", + " attribute {\n", + " name: \"rounding_mode\"\n", + " s: \"ROUND\"\n", + " type: STRING\n", + " }\n", + " attribute {\n", + " name: \"signed\"\n", + " i: 1\n", + " type: INT\n", + " }\n", + " domain: \"onnx.brevitas\"\n", + " }\n", + " node {\n", + " input: \"bias\"\n", + " input: \"onnx.brevitas::Quant_11\"\n", + " input: \"/input_quant/export_handler/Constant_2_output_0\"\n", + " input: \"/bias_quant/export_handler/Constant_output_0\"\n", + " output: \"/bias_quant/export_handler/Quant_output_0\"\n", + " name: \"/bias_quant/export_handler/Quant\"\n", + " op_type: \"Quant\"\n", + " attribute {\n", + " name: \"narrow\"\n", + " i: 0\n", + " type: INT\n", + " }\n", + " attribute {\n", + " name: \"rounding_mode\"\n", + " s: \"ROUND\"\n", + " type: STRING\n", + " }\n", + " attribute {\n", + " name: \"signed\"\n", + " i: 1\n", + " type: INT\n", + " }\n", + " domain: \"onnx.brevitas\"\n", + " }\n", + " node {\n", + " input: \"/input_quant/export_handler/Quant_output_0\"\n", + " input: \"/weight_quant/export_handler/Quant_output_0\"\n", + " input: \"/bias_quant/export_handler/Quant_output_0\"\n", + " output: \"/Conv_output_0\"\n", + " name: \"/Conv\"\n", + " op_type: \"Conv\"\n", + " attribute {\n", + " name: \"dilations\"\n", + " ints: 1\n", + " type: INTS\n", + " }\n", + " attribute {\n", + " name: \"group\"\n", + " i: 1\n", + " type: INT\n", + " }\n", + " attribute {\n", + " name: \"kernel_shape\"\n", + " ints: 3\n", + " type: INTS\n", + " }\n", + " attribute {\n", + " name: \"pads\"\n", + " ints: 0\n", + " ints: 0\n", + " type: INTS\n", + " }\n", + " attribute {\n", + " name: \"strides\"\n", + " ints: 1\n", + " type: INTS\n", + " }\n", + " }\n", + " node {\n", + " input: \"/Conv_output_0\"\n", + " input: \"/input_quant/export_handler/Constant_1_output_0\"\n", + " input: \"/input_quant/export_handler/Constant_2_output_0\"\n", + " input: \"/input_quant/export_handler/Constant_output_0\"\n", + " output: \"15\"\n", + " name: \"/output_quant/export_handler/Quant\"\n", + " op_type: \"Quant\"\n", + " attribute {\n", + " name: \"narrow\"\n", + " i: 0\n", + " type: INT\n", + " }\n", + " attribute {\n", + " name: \"rounding_mode\"\n", + " s: \"ROUND\"\n", + " type: STRING\n", + " }\n", + " attribute {\n", + " name: \"signed\"\n", + " i: 1\n", + " type: INT\n", + " }\n", + " domain: \"onnx.brevitas\"\n", + " }\n", + " name: \"main_graph\"\n", + " initializer {\n", + " dims: 4\n", + " data_type: 1\n", + " name: \"bias\"\n", + " raw_data: \"w\\010\\227\\276\\360\\203W\\276q\\341\\203>\\002\\034u>\"\n", + " }\n", + " initializer {\n", + " data_type: 1\n", + " name: \"/input_quant/export_handler/Constant_output_0\"\n", + " raw_data: \"\\000\\000\\000A\"\n", + " }\n", + " initializer {\n", + " data_type: 1\n", + " name: \"/input_quant/export_handler/Constant_1_output_0\"\n", + " raw_data: \"\\000\\000\\000<\"\n", + " }\n", + " initializer {\n", + " data_type: 1\n", + " name: \"/input_quant/export_handler/Constant_2_output_0\"\n", + " raw_data: \"\\000\\000\\000\\000\"\n", + " }\n", + " initializer {\n", + " data_type: 1\n", + " name: \"/weight_quant/export_handler/Constant_output_0\"\n", + " raw_data: \"\\000\\000\\200@\"\n", + " }\n", + " initializer {\n", + " 
dims: 4\n", + " dims: 2\n", + " dims: 3\n", + " data_type: 1\n", + " name: \"/weight_quant/export_handler/Constant_1_output_0\"\n", + " raw_data: \"\\372\\313\\'>\\372\\313\\247>\\242\\272\\337\\275\\372\\313\\247>\\242\\272_\\275\\242\\272_=N\\303\\303>N\\303\\303\\276\\245\\324\\213\\276\\242\\272\\337\\275\\372\\313\\'\\276\\372\\313\\247>\\245\\324\\213\\276\\372\\313\\'\\276\\245\\324\\213\\276N\\303\\303\\276\\242\\272_\\276\\372\\313\\247>\\372\\313\\'>\\242\\272_>\\000\\000\\000\\000\\242\\272_\\276\\242\\272_=N\\303\\303\\276\"\n", + " }\n", + " initializer {\n", + " data_type: 1\n", + " name: \"/weight_quant/export_handler/Constant_2_output_0\"\n", + " raw_data: \"\\242\\272_=\"\n", + " }\n", + " initializer {\n", + " dims: 1\n", + " data_type: 1\n", + " name: \"onnx.brevitas::Quant_11\"\n", + " raw_data: \"\\242\\272\\3379\"\n", + " }\n", + " initializer {\n", + " data_type: 1\n", + " name: \"/bias_quant/export_handler/Constant_output_0\"\n", + " raw_data: \"\\000\\000\\200A\"\n", + " }\n", + " input {\n", + " name: \"x.87\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 1\n", + " }\n", + " dim {\n", + " dim_value: 2\n", + " }\n", + " dim {\n", + " dim_value: 5\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " input {\n", + " name: \"bias\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 4\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " output {\n", + " name: \"15\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 1\n", + " }\n", + " dim {\n", + " dim_value: 4\n", + " }\n", + " dim {\n", + " dim_value: 3\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " value_info {\n", + " name: \"/input_quant/export_handler/Quant_output_0\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 1\n", + " }\n", + " dim {\n", + " dim_value: 2\n", + " }\n", + " dim {\n", + " dim_value: 5\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " value_info {\n", + " name: \"/weight_quant/export_handler/Quant_output_0\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 4\n", + " }\n", + " dim {\n", + " dim_value: 2\n", + " }\n", + " dim {\n", + " dim_value: 3\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " value_info {\n", + " name: \"/bias_quant/export_handler/Quant_output_0\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 4\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + "}\n", + "opset_import {\n", + " domain: \"\"\n", + " version: 17\n", + "}\n", + "opset_import {\n", + " domain: \"onnx.brevitas\"\n", + " version: 1\n", + "}" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "torch.manual_seed(0)\n", "\n", @@ -2003,7 +2528,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": null, "metadata": { "tags": [ "skip-execution" @@ -2053,9 +2578,191 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 42, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "ir_version: 8\n", + "producer_name: \"pytorch\"\n", + "producer_version: \"2.1.0\"\n", + "graph {\n", + " node {\n", + " input: \"/weight_quant/export_handler/Constant_1_output_0\"\n", + " input: 
\"/weight_quant/export_handler/Constant_2_output_0\"\n", + " input: \"/weight_quant/export_handler/Constant_3_output_0\"\n", + " input: \"/weight_quant/export_handler/Constant_output_0\"\n", + " output: \"/weight_quant/export_handler/Quant_output_0\"\n", + " name: \"/weight_quant/export_handler/Quant\"\n", + " op_type: \"Quant\"\n", + " attribute {\n", + " name: \"narrow\"\n", + " i: 1\n", + " type: INT\n", + " }\n", + " attribute {\n", + " name: \"rounding_mode\"\n", + " s: \"ROUND\"\n", + " type: STRING\n", + " }\n", + " attribute {\n", + " name: \"signed\"\n", + " i: 1\n", + " type: INT\n", + " }\n", + " domain: \"onnx.brevitas\"\n", + " }\n", + " node {\n", + " input: \"x.27\"\n", + " input: \"/weight_quant/export_handler/Quant_output_0\"\n", + " input: \"bias\"\n", + " output: \"8\"\n", + " name: \"/Conv\"\n", + " op_type: \"Conv\"\n", + " attribute {\n", + " name: \"dilations\"\n", + " ints: 1\n", + " type: INTS\n", + " }\n", + " attribute {\n", + " name: \"group\"\n", + " i: 1\n", + " type: INT\n", + " }\n", + " attribute {\n", + " name: \"kernel_shape\"\n", + " ints: 3\n", + " type: INTS\n", + " }\n", + " attribute {\n", + " name: \"pads\"\n", + " ints: 0\n", + " ints: 0\n", + " type: INTS\n", + " }\n", + " attribute {\n", + " name: \"strides\"\n", + " ints: 1\n", + " type: INTS\n", + " }\n", + " }\n", + " name: \"main_graph\"\n", + " initializer {\n", + " dims: 4\n", + " data_type: 1\n", + " name: \"bias\"\n", + " raw_data: \"\\243\\303\\206\\275\\325\\3600=\\366C\\275>\\222\\347\\301\\276\"\n", + " }\n", + " initializer {\n", + " data_type: 1\n", + " name: \"/weight_quant/export_handler/Constant_output_0\"\n", + " raw_data: \"\\000\\000\\200@\"\n", + " }\n", + " initializer {\n", + " dims: 4\n", + " dims: 2\n", + " dims: 3\n", + " data_type: 1\n", + " name: \"/weight_quant/export_handler/Constant_1_output_0\"\n", + " raw_data: \"\\000\\000\\000\\200\\2227d>\\256)\\253\\276\\273\\242\\216\\276\\256)+\\276\\2227\\344=\\000\\000\\000\\200\\256)\\253>\\2227d\\275\\2227\\344=\\2227\\344\\275\\2227d\\275\\240\\260\\307\\276\\273\\242\\216\\276\\256)+\\276\\000\\000\\000\\000\\256)+>\\2227d>\\273\\242\\216\\276\\256)+\\276\\256)+>\\256)\\253>\\2227\\344\\275\\273\\242\\216>\"\n", + " }\n", + " initializer {\n", + " data_type: 1\n", + " name: \"/weight_quant/export_handler/Constant_2_output_0\"\n", + " raw_data: \"\\2227d=\"\n", + " }\n", + " initializer {\n", + " data_type: 1\n", + " name: \"/weight_quant/export_handler/Constant_3_output_0\"\n", + " raw_data: \"\\000\\000\\000\\000\"\n", + " }\n", + " input {\n", + " name: \"x.27\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 1\n", + " }\n", + " dim {\n", + " dim_value: 2\n", + " }\n", + " dim {\n", + " dim_value: 5\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " input {\n", + " name: \"bias\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 4\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " output {\n", + " name: \"8\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 1\n", + " }\n", + " dim {\n", + " dim_value: 4\n", + " }\n", + " dim {\n", + " dim_value: 3\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " value_info {\n", + " name: \"/weight_quant/export_handler/Quant_output_0\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 4\n", + " }\n", + " dim {\n", + " dim_value: 2\n", 
+ " }\n", + " dim {\n", + " dim_value: 3\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + "}\n", + "opset_import {\n", + " domain: \"\"\n", + " version: 17\n", + "}\n", + "opset_import {\n", + " domain: \"onnx.brevitas\"\n", + " version: 1\n", + "}" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "torch.manual_seed(0)\n", "\n", @@ -2067,7 +2774,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 43, "metadata": { "tags": [ "skip-execution" @@ -2096,10 +2803,10 @@ " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 41, + "execution_count": 43, "metadata": {}, "output_type": "execute_result" } @@ -2121,9 +2828,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 44, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "RecursiveScriptModule(original_name=_JitTraceExportWrapper)" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from brevitas.quant import ShiftedUint8ActPerTensorFloat\n", "from brevitas.export import export_torch_qop\n", @@ -2142,21 +2860,13 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 45, "metadata": { "tags": [ "skip-execution" ] }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\users\\alessandro\\documenti\\brevitas_tvmcon\\src\\brevitas\\quant_tensor\\__init__.py:74: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", - " training = torch.tensor(training, dtype=torch.bool)\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -2179,10 +2889,10 @@ " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 42, + "execution_count": 45, "metadata": {}, "output_type": "execute_result" } @@ -2239,9 +2949,24 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 46, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/overrides.py:110: UserWarning: 'has_cuda' is deprecated, please use 'torch.backends.cuda.is_built()'\n", + " torch.has_cuda,\n", + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/overrides.py:111: UserWarning: 'has_cudnn' is deprecated, please use 'torch.backends.cudnn.is_available()'\n", + " torch.has_cudnn,\n", + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/overrides.py:117: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'\n", + " torch.has_mps,\n", + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/overrides.py:118: UserWarning: 'has_mkldnn' is deprecated, please use 'torch.backends.mkldnn.is_available()'\n", + " torch.has_mkldnn,\n" + ] + } + ], "source": [ "from brevitas.graph.calibrate import bias_correction_mode\n", "from brevitas.graph.calibrate import calibration_mode\n", @@ -2280,7 +3005,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.15 | packaged by conda-forge | (default, Nov 22 2022, 08:42:03) [MSC v.1929 64 bit (AMD64)]" + "version": "3.11.5" }, "vscode": { "interpreter": { diff --git 
a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index 12a252398..1410ebeb0 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -175,13 +175,13 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]): # Avoid inplace operations on the input in case of forward hooks if not torch._C._get_tracing_state(): if isinstance(inp, QuantTensor): - inp = inp.set(qt_value=inp.qt_value.rename(None)) + inp = inp.set(value=inp.value.rename(None)) else: inp = inp.rename(None) return inp def pack_output(self, quant_output: QuantTensor): - if not self.training and self.cache_inference_quant_out: + if not self.training and self.cache_inference_quant_out and isinstance(quant_output, QuantTensor): self._cached_out = _CachedIO(quant_output.detach(), self.cache_quant_io_metadata_only) self._set_global_is_quant_layer(False) if self.return_quant_tensor: diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index f20b93f65..9897d0851 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -139,7 +139,8 @@ def forward(self, input: Union[Tensor, QuantTensor]): quant_input = self.input_quant(input) # shortcut execution through the export impl during export if self.export_mode: - out = self.export_handler(quant_input.value) + quant_input_value = getattr(quant_input, 'value', quant_input) + out = self.export_handler(quant_input_value) self._set_global_is_quant_layer(False) return out out = self.act_quant(quant_input) @@ -348,7 +349,7 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe quant_bias_value = getattr(quant_bias, 'value', quant_bias) quant_bias_scale = getattr(quant_bias, 'scale', None) quant_bias_bitwidth = getattr(quant_bias, 'bit_width', None) - if not self.training and self.cache_inference_quant_bias: + if not self.training and self.cache_inference_quant_bias and isinstance(quant_bias, QuantTensor): self._cached_bias = _CachedIO(quant_bias.detach(), metadata_only=False) output_tensor = self.inner_forward_impl( return_value(quant_input), return_value(quant_weight), return_value(quant_bias)) @@ -384,7 +385,7 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe if not self.return_quant_tensor or not compute_output_quant_tensor: quant_output = output_tensor else: - quant_output = QuantTensor.from_fake_quantized( + quant_output = QuantTensor( output_tensor, scale=output_scale, zero_point=output_zero_point, diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index e67446e3a..b2b10c08a 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -158,7 +158,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor: if isinstance(y, tuple): y = y[0] if isinstance(x, QuantTensor): - return QuantTensor.from_fake_quantized( + return QuantTensor( y, x.scale, x.zero_point, x.bit_width, x.signed, self.training) else: return y From 979d5453e93c86d078c4151c455bf3675d99082e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 2 Feb 2024 09:53:23 +0000 Subject: [PATCH 03/32] More fix --- src/brevitas/nn/quant_layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index 9897d0851..ee5b8e628 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -139,8 +139,8 @@ def forward(self, input: Union[Tensor, QuantTensor]): quant_input = self.input_quant(input) # shortcut 
execution through the export impl during export if self.export_mode: - quant_input_value = getattr(quant_input, 'value', quant_input) - out = self.export_handler(quant_input_value) + # quant_input_value = getattr(quant_input, 'value', quant_input) + out = self.export_handler(quant_input) self._set_global_is_quant_layer(False) return out out = self.act_quant(quant_input) From 601898e35920cf579ccab094facdc81e5436046d Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 2 Feb 2024 10:24:08 +0000 Subject: [PATCH 04/32] Fix recurrent --- notebooks/Brevitas_TVMCon2021.ipynb | 135 +++--- notebooks/ONNX_export_tutorial.ipynb | 89 ++-- notebooks/quantized_recurrent.ipynb | 639 +++++++++++++-------------- src/brevitas/nn/mixin/base.py | 7 +- 4 files changed, 423 insertions(+), 447 deletions(-) diff --git a/notebooks/Brevitas_TVMCon2021.ipynb b/notebooks/Brevitas_TVMCon2021.ipynb index 2a11e50cc..e39b7301d 100644 --- a/notebooks/Brevitas_TVMCon2021.ipynb +++ b/notebooks/Brevitas_TVMCon2021.ipynb @@ -612,7 +612,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": { "tags": [ "raises-exception" @@ -626,11 +626,11 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/Brevitas_TVMCon2021.ipynb Cell 35\u001b[0m line \u001b[0;36m8\n\u001b[1;32m 5\u001b[0m float_input \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mrandn(\u001b[39m3\u001b[39m, \u001b[39m2\u001b[39m)\n\u001b[1;32m 6\u001b[0m quant_linear \u001b[39m=\u001b[39m QuantLinear(\u001b[39m2\u001b[39m, \u001b[39m4\u001b[39m, bias\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m, bias_quant\u001b[39m=\u001b[39mInt16Bias, return_quant_tensor\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[0;32m----> 8\u001b[0m quant_output \u001b[39m=\u001b[39m quant_linear(float_input)\n", + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/Brevitas_TVMCon2021.ipynb Cell 35\u001b[0m line \u001b[0;36m8\n\u001b[1;32m 5\u001b[0m float_input \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mrandn(\u001b[39m3\u001b[39m, \u001b[39m2\u001b[39m)\n\u001b[1;32m 6\u001b[0m quant_linear \u001b[39m=\u001b[39m QuantLinear(\u001b[39m2\u001b[39m, \u001b[39m4\u001b[39m, bias\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m, bias_quant\u001b[39m=\u001b[39mInt16Bias, return_quant_tensor\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[0;32m----> 8\u001b[0m quant_output \u001b[39m=\u001b[39m quant_linear(float_input)\n", "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compiled_call_impl(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs) \u001b[39m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_impl(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 
1522\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1529\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n", "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_linear.py:66\u001b[0m, in \u001b[0;36mQuantLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m: Union[Tensor, QuantTensor]) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Union[Tensor, QuantTensor]:\n\u001b[0;32m---> 66\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mforward_impl(\u001b[39minput\u001b[39;49m)\n", - "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py:328\u001b[0m, in \u001b[0;36mQuantWeightBiasInputOutputLayer.forward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 324\u001b[0m compute_output_quant_tensor \u001b[39m=\u001b[39m \u001b[39misinstance\u001b[39m(quant_input, QuantTensor) \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(\n\u001b[1;32m 325\u001b[0m quant_weight, QuantTensor)\n\u001b[1;32m 326\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (compute_output_quant_tensor \u001b[39mor\u001b[39;00m\n\u001b[1;32m 327\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_output_quant_enabled) \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreturn_quant_tensor:\n\u001b[0;32m--> 328\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mQuantLayer is not correctly configured\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 330\u001b[0m \u001b[39mif\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreturn_quant_tensor \u001b[39mor\u001b[39;00m\n\u001b[1;32m 331\u001b[0m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_bias_quant_enabled \u001b[39mand\u001b[39;00m\n\u001b[1;32m 332\u001b[0m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias_quant\u001b[39m.\u001b[39mrequires_input_scale \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias_quant\u001b[39m.\u001b[39mrequires_input_bit_width))):\n\u001b[1;32m 333\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(quant_input, QuantTensor) \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(quant_weight, QuantTensor):\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py:329\u001b[0m, in 
\u001b[0;36mQuantWeightBiasInputOutputLayer.forward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 325\u001b[0m compute_output_quant_tensor \u001b[39m=\u001b[39m \u001b[39misinstance\u001b[39m(quant_input, QuantTensor) \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(\n\u001b[1;32m 326\u001b[0m quant_weight, QuantTensor)\n\u001b[1;32m 327\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (compute_output_quant_tensor \u001b[39mor\u001b[39;00m\n\u001b[1;32m 328\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_output_quant_enabled) \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreturn_quant_tensor:\n\u001b[0;32m--> 329\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mQuantLayer is not correctly configured\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 331\u001b[0m \u001b[39mif\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreturn_quant_tensor \u001b[39mor\u001b[39;00m\n\u001b[1;32m 332\u001b[0m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_bias_quant_enabled \u001b[39mand\u001b[39;00m\n\u001b[1;32m 333\u001b[0m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias_quant\u001b[39m.\u001b[39mrequires_input_scale \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias_quant\u001b[39m.\u001b[39mrequires_input_bit_width))):\n\u001b[1;32m 334\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(quant_input, QuantTensor) \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(quant_weight, QuantTensor):\n", "\u001b[0;31mRuntimeError\u001b[0m: QuantLayer is not correctly configured" ] } @@ -655,7 +655,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -666,7 +666,7 @@ " [ 1.6461, 0.8687, -1.0466, 1.4813]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(18.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 15, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -712,7 +712,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -778,7 +778,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -831,7 +831,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -863,7 +863,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/tmp/ipykernel_1718099/661358273.py:7: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n", + "/tmp/ipykernel_1735865/661358273.py:7: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. 
(Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n", " quant_output = torch.tanh(quant_input)\n" ] } @@ -892,7 +892,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -916,9 +916,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "/tmp/ipykernel_1718099/3932472163.py:8: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n", + "/tmp/ipykernel_1735865/3932472163.py:8: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n", " train_mode_cat = torch.cat([quant_identity(float_inp1), quant_identity(float_inp2)], dim=1)\n", - "/tmp/ipykernel_1718099/3932472163.py:14: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n", + "/tmp/ipykernel_1735865/3932472163.py:14: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n", " eval_mode_cat = torch.cat([eval_quant_inp1, eval_quant_inp2], dim=1)\n" ] } @@ -965,7 +965,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -999,7 +999,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -1036,7 +1036,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -1069,7 +1069,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -1080,7 +1080,7 @@ " [0.0000, 0.0000]], grad_fn=), scale=tensor(0.0235, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))" ] }, - "execution_count": 23, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -1106,7 +1106,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -1164,7 +1164,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 24, "metadata": {}, "outputs": [ { @@ -1238,7 +1238,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -1272,7 +1272,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -1312,7 +1312,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 27, "metadata": {}, "outputs": [ { @@ -1356,7 +1356,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 28, "metadata": {}, "outputs": [ { @@ -1368,7 +1368,7 @@ " grad_fn=), zero_point=tensor(0.), bit_width=tensor(17., grad_fn=), signed_t=tensor(True), training_t=tensor(True))" ] }, - 
"execution_count": 29, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -1413,7 +1413,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 29, "metadata": { "tags": [ "raises-exception" @@ -1427,7 +1427,7 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/Brevitas_TVMCon2021.ipynb Cell 75\u001b[0m line \u001b[0;36m1\n\u001b[1;32m 5\u001b[0m float_linear \u001b[39m=\u001b[39m nn\u001b[39m.\u001b[39mLinear(\u001b[39m2\u001b[39m, \u001b[39m4\u001b[39m, bias\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m)\n\u001b[1;32m 6\u001b[0m quant_linear \u001b[39m=\u001b[39m QuantLinear(\n\u001b[1;32m 7\u001b[0m \u001b[39m2\u001b[39m, \u001b[39m4\u001b[39m, \n\u001b[1;32m 8\u001b[0m input_quant\u001b[39m=\u001b[39mLearnedIntActPerTensorFloat,\n\u001b[1;32m 9\u001b[0m weight_quant\u001b[39m=\u001b[39mLearnedIntWeightPerChannelFloat, \n\u001b[1;32m 10\u001b[0m return_quant_tensor\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m, bias\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m)\n\u001b[0;32m---> 12\u001b[0m quant_linear\u001b[39m.\u001b[39;49mload_state_dict(float_linear\u001b[39m.\u001b[39;49mstate_dict())\n", + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/Brevitas_TVMCon2021.ipynb Cell 75\u001b[0m line \u001b[0;36m1\n\u001b[1;32m 5\u001b[0m float_linear \u001b[39m=\u001b[39m nn\u001b[39m.\u001b[39mLinear(\u001b[39m2\u001b[39m, \u001b[39m4\u001b[39m, bias\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m)\n\u001b[1;32m 6\u001b[0m quant_linear \u001b[39m=\u001b[39m QuantLinear(\n\u001b[1;32m 7\u001b[0m \u001b[39m2\u001b[39m, \u001b[39m4\u001b[39m, \n\u001b[1;32m 8\u001b[0m input_quant\u001b[39m=\u001b[39mLearnedIntActPerTensorFloat,\n\u001b[1;32m 9\u001b[0m weight_quant\u001b[39m=\u001b[39mLearnedIntWeightPerChannelFloat, \n\u001b[1;32m 10\u001b[0m return_quant_tensor\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m, bias\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m)\n\u001b[0;32m---> 12\u001b[0m quant_linear\u001b[39m.\u001b[39;49mload_state_dict(float_linear\u001b[39m.\u001b[39;49mstate_dict())\n", "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:2152\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[0;34m(self, state_dict, strict, assign)\u001b[0m\n\u001b[1;32m 2147\u001b[0m error_msgs\u001b[39m.\u001b[39minsert(\n\u001b[1;32m 2148\u001b[0m \u001b[39m0\u001b[39m, \u001b[39m'\u001b[39m\u001b[39mMissing key(s) in state_dict: \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m. 
\u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 2149\u001b[0m \u001b[39m'\u001b[39m\u001b[39m, \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mjoin(\u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00mk\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m'\u001b[39m \u001b[39mfor\u001b[39;00m k \u001b[39min\u001b[39;00m missing_keys)))\n\u001b[1;32m 2151\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(error_msgs) \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[0;32m-> 2152\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m'\u001b[39m\u001b[39mError(s) in loading state_dict for \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 2153\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__class__\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mjoin(error_msgs)))\n\u001b[1;32m 2154\u001b[0m \u001b[39mreturn\u001b[39;00m _IncompatibleKeys(missing_keys, unexpected_keys)\n", "\u001b[0;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for QuantLinear:\n\tMissing key(s) in state_dict: \"input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value\", \"input_quant.fused_activation_quant_proxy.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\", \"weight_quant.tensor_quant.scaling_impl.value\", \"weight_quant.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\". " ] @@ -1459,7 +1459,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 30, "metadata": {}, "outputs": [ { @@ -1468,7 +1468,7 @@ "" ] }, - "execution_count": 31, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } @@ -1500,7 +1500,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 31, "metadata": {}, "outputs": [ { @@ -1567,7 +1567,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 32, "metadata": {}, "outputs": [ { @@ -1616,7 +1616,7 @@ ")" ] }, - "execution_count": 33, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } @@ -1638,7 +1638,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 33, "metadata": {}, "outputs": [], "source": [ @@ -1690,7 +1690,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 34, "metadata": {}, "outputs": [], "source": [ @@ -1750,7 +1750,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 35, "metadata": {}, "outputs": [], "source": [ @@ -1806,7 +1806,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 36, "metadata": {}, "outputs": [ { @@ -1815,7 +1815,7 @@ "True" ] }, - "execution_count": 37, + "execution_count": 36, "metadata": {}, "output_type": "execute_result" } @@ -1866,7 +1866,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 37, "metadata": {}, "outputs": [ { @@ -1890,7 +1890,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 38, "metadata": {}, "outputs": [], "source": [ @@ -1916,7 +1916,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 39, "metadata": {}, "outputs": [ { @@ -2107,7 +2107,7 @@ "}" ] }, - "execution_count": 40, + "execution_count": 39, "metadata": {}, "output_type": 
"execute_result" } @@ -2133,7 +2133,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 40, "metadata": { "tags": [ "skip-execution" @@ -2141,33 +2141,22 @@ }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Serving 'qop_onnx_conv_4b8b.onnx' at http://localhost:8082\n" + "ename": "OSError", + "evalue": "[Errno 98] Address already in use", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/Brevitas_TVMCon2021.ipynb Cell 103\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m show_netron(output_path, \u001b[39m8082\u001b[39;49m)\n", + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/Brevitas_TVMCon2021.ipynb Cell 103\u001b[0m line \u001b[0;36m7\n\u001b[1;32m 5\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mshow_netron\u001b[39m(model_path, port):\n\u001b[1;32m 6\u001b[0m time\u001b[39m.\u001b[39msleep(\u001b[39m3.\u001b[39m)\n\u001b[0;32m----> 7\u001b[0m netron\u001b[39m.\u001b[39;49mstart(model_path, address\u001b[39m=\u001b[39;49m(\u001b[39m\"\u001b[39;49m\u001b[39mlocalhost\u001b[39;49m\u001b[39m\"\u001b[39;49m, port), browse\u001b[39m=\u001b[39;49m\u001b[39mFalse\u001b[39;49;00m)\n\u001b[1;32m 8\u001b[0m \u001b[39mreturn\u001b[39;00m IFrame(src\u001b[39m=\u001b[39m\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mhttp://localhost:\u001b[39m\u001b[39m{\u001b[39;00mport\u001b[39m}\u001b[39;00m\u001b[39m/\u001b[39m\u001b[39m\"\u001b[39m, width\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m100\u001b[39m\u001b[39m%\u001b[39m\u001b[39m\"\u001b[39m, height\u001b[39m=\u001b[39m\u001b[39m400\u001b[39m)\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/netron/server.py:321\u001b[0m, in \u001b[0;36mstart\u001b[0;34m(file, address, browse, verbosity)\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mstart\u001b[39m(file\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m, address\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m, browse\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m, verbosity\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m):\n\u001b[1;32m 310\u001b[0m \u001b[39m \u001b[39m\u001b[39m'''Start serving model file at address and open in web browser.\u001b[39;00m\n\u001b[1;32m 311\u001b[0m \n\u001b[1;32m 312\u001b[0m \u001b[39m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 319\u001b[0m \u001b[39m A (host, port) address tuple.\u001b[39;00m\n\u001b[1;32m 320\u001b[0m \u001b[39m '''\u001b[39;00m\n\u001b[0;32m--> 321\u001b[0m \u001b[39mreturn\u001b[39;00m serve(file, \u001b[39mNone\u001b[39;49;00m, browse\u001b[39m=\u001b[39;49mbrowse, address\u001b[39m=\u001b[39;49maddress, verbosity\u001b[39m=\u001b[39;49mverbosity)\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/netron/server.py:298\u001b[0m, in \u001b[0;36mserve\u001b[0;34m(file, data, address, browse, verbosity)\u001b[0m\n\u001b[1;32m 295\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 296\u001b[0m address \u001b[39m=\u001b[39m _make_port(address)\n\u001b[0;32m--> 298\u001b[0m thread \u001b[39m=\u001b[39m _HTTPServerThread(content, address, verbosity)\n\u001b[1;32m 299\u001b[0m thread\u001b[39m.\u001b[39mstart()\n\u001b[1;32m 300\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mnot\u001b[39;00m thread\u001b[39m.\u001b[39malive():\n", + "File 
\u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/netron/server.py:129\u001b[0m, in \u001b[0;36m_HTTPServerThread.__init__\u001b[0;34m(self, content, address, verbosity)\u001b[0m\n\u001b[1;32m 127\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39maddress \u001b[39m=\u001b[39m address\n\u001b[1;32m 128\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39murl \u001b[39m=\u001b[39m \u001b[39m'\u001b[39m\u001b[39mhttp://\u001b[39m\u001b[39m'\u001b[39m \u001b[39m+\u001b[39m address[\u001b[39m0\u001b[39m] \u001b[39m+\u001b[39m \u001b[39m'\u001b[39m\u001b[39m:\u001b[39m\u001b[39m'\u001b[39m \u001b[39m+\u001b[39m \u001b[39mstr\u001b[39m(address[\u001b[39m1\u001b[39m])\n\u001b[0;32m--> 129\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mserver \u001b[39m=\u001b[39m _ThreadedHTTPServer(address, _HTTPRequestHandler)\n\u001b[1;32m 130\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mserver\u001b[39m.\u001b[39mtimeout \u001b[39m=\u001b[39m \u001b[39m0.25\u001b[39m\n\u001b[1;32m 131\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mserver\u001b[39m.\u001b[39mblock_on_close \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/socketserver.py:456\u001b[0m, in \u001b[0;36mTCPServer.__init__\u001b[0;34m(self, server_address, RequestHandlerClass, bind_and_activate)\u001b[0m\n\u001b[1;32m 454\u001b[0m \u001b[39mif\u001b[39;00m bind_and_activate:\n\u001b[1;32m 455\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 456\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mserver_bind()\n\u001b[1;32m 457\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mserver_activate()\n\u001b[1;32m 458\u001b[0m \u001b[39mexcept\u001b[39;00m:\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/http/server.py:136\u001b[0m, in \u001b[0;36mHTTPServer.server_bind\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mserver_bind\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[1;32m 135\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"Override server_bind to store the server name.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 136\u001b[0m socketserver\u001b[39m.\u001b[39;49mTCPServer\u001b[39m.\u001b[39;49mserver_bind(\u001b[39mself\u001b[39;49m)\n\u001b[1;32m 137\u001b[0m host, port \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mserver_address[:\u001b[39m2\u001b[39m]\n\u001b[1;32m 138\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mserver_name \u001b[39m=\u001b[39m socket\u001b[39m.\u001b[39mgetfqdn(host)\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/socketserver.py:472\u001b[0m, in \u001b[0;36mTCPServer.server_bind\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 470\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mallow_reuse_port \u001b[39mand\u001b[39;00m \u001b[39mhasattr\u001b[39m(socket, \u001b[39m\"\u001b[39m\u001b[39mSO_REUSEPORT\u001b[39m\u001b[39m\"\u001b[39m):\n\u001b[1;32m 471\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39msocket\u001b[39m.\u001b[39msetsockopt(socket\u001b[39m.\u001b[39mSOL_SOCKET, socket\u001b[39m.\u001b[39mSO_REUSEPORT, \u001b[39m1\u001b[39m)\n\u001b[0;32m--> 472\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49msocket\u001b[39m.\u001b[39;49mbind(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mserver_address)\n\u001b[1;32m 473\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mserver_address \u001b[39m=\u001b[39m 
\u001b[39mself\u001b[39m\u001b[39m.\u001b[39msocket\u001b[39m.\u001b[39mgetsockname()\n", + "\u001b[0;31mOSError\u001b[0m: [Errno 98] Address already in use" ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 39, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ @@ -2195,7 +2184,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2578,7 +2567,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2774,7 +2763,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": null, "metadata": { "tags": [ "skip-execution" @@ -2828,7 +2817,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2860,7 +2849,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": null, "metadata": { "tags": [ "skip-execution" @@ -2949,7 +2938,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 41, "metadata": {}, "outputs": [ { diff --git a/notebooks/ONNX_export_tutorial.ipynb b/notebooks/ONNX_export_tutorial.ipynb index 304161fce..e7a6659d2 100644 --- a/notebooks/ONNX_export_tutorial.ipynb +++ b/notebooks/ONNX_export_tutorial.ipynb @@ -22,9 +22,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33mWARNING: Ignoring invalid distribution ~nnx-weekly (/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages)\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: netron in /home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages (7.2.9)\n", + "\u001b[33mWARNING: Ignoring invalid distribution ~nnx-weekly (/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages)\u001b[0m\u001b[33m\n", + "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n" + ] + } + ], "source": [ "%pip install netron" ] @@ -95,7 +106,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 14, "metadata": { "collapsed": false, "pycharm": { @@ -116,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 15, "metadata": { "collapsed": false, "pycharm": { @@ -142,7 +153,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 16, "metadata": { "collapsed": false, "pycharm": { @@ -157,6 +168,7 @@ "name": "stdout", "output_type": "stream", "text": [ + "Stopping http://localhost:8082\n", "Serving 'quant_linear_qcdq.onnx' at http://localhost:8082\n" ] }, @@ -175,10 +187,10 @@ " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 3, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -219,7 +231,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 17, "metadata": { "collapsed": false, "pycharm": { @@ -248,7 +260,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 18, "metadata": { "collapsed": false, "pycharm": { @@ -263,6 +275,7 @@ "name": "stdout", "output_type": "stream", "text": [ + "Stopping http://localhost:8083\n", "Serving 'quant_model_qcdq.onnx' at http://localhost:8083\n" ] }, @@ -281,10 +294,10 @@ " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 5, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -334,7 +347,7 
@@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 19, "metadata": { "collapsed": false, "pycharm": { @@ -365,7 +378,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 20, "metadata": { "collapsed": false, "pycharm": { @@ -398,10 +411,10 @@ " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 7, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -446,7 +459,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 21, "metadata": { "collapsed": false, "pycharm": { @@ -458,7 +471,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "c:\\users\\alessand\\documents\\brevitas\\src\\brevitas\\export\\onnx\\standard\\manager.py:23: UserWarning: ONNX opset version set to 13, override with opset_version=\n", + "/home/giuseppe/Documents/git/brevitas/src/brevitas/export/onnx/standard/manager.py:26: UserWarning: ONNX opset version set to 13, override with opset_version=\n", " warnings.warn(f\"ONNX opset version set to {DEFAULT_OPSET}, override with {ka}=\")\n" ] }, @@ -467,7 +480,7 @@ "text/plain": [ "ir_version: 7\n", "producer_name: \"pytorch\"\n", - "producer_version: \"1.13.1\"\n", + "producer_version: \"2.1.0\"\n", "graph {\n", " node {\n", " output: \"/input_quant/export_handler/Constant_output_0\"\n", @@ -496,7 +509,7 @@ " }\n", " }\n", " node {\n", - " input: \"inp.1\"\n", + " input: \"out.1\"\n", " input: \"/input_quant/export_handler/Constant_output_0\"\n", " input: \"/input_quant/export_handler/Constant_1_output_0\"\n", " output: \"/input_quant/export_handler/QuantizeLinear_output_0\"\n", @@ -515,7 +528,7 @@ " dims: 3\n", " dims: 3\n", " data_type: 3\n", - " raw_data: \"\\374\\372\\376\\374\\005\\000\\375\\374\\004\\375\\373\\373\\375\\007\\376\\374\\377\\000\\000\\373\\373\\004\\005\\371\\003\\375\\004\\373\\004\\374\\000\\006\\002\\003\\003\\005\\004\\377\\005\\000\\373\\376\\375\\376\\002\\376\\004\\377\\003\\005\\375\\371\\006\\373\\003\\007\\377\\374\\005\\375\\375\\006\\375\\377\\374\\001\\005\\371\\006\\005\\007\\376\\376\\372\\376\\004\\001\\374\\002\\373\\373\\376\\002\\376\\375\\377\\001\\376\\006\\371\\002\\000\\004\\005\\005\\000\\004\\373\\004\\002\\003\\000\\374\\376\\005\\000\\004\\372\\004\\000\\373\\000\\006\\377\\002\\005\\004\\005\\374\\000\\007\\377\\374\\371\\373\\007\\004\\376\\372\\001\\005\\001\\372\\377\\003\\001\\375\\006\\372\\377\\006\\003\\006\\004\\001\\004\\372\\005\\006\\003\\376\\373\\374\\375\\376\\005\\000\\004\\377\\372\\373\\000\\007\\377\\373\\003\\373\\376\\374\\374\\377\\375\\377\\003\\372\\005\\004\\007\\003\\375\\377\\001\\007\\377\\373\\374\\000\\377\\376\\374\\373\\377\\373\\375\\003\\004\\004\\376\\004\\377\\375\\003\\003\\377\\004\\000\\005\\004\\000\\372\\005\\007\\003\\004\\377\\373\\003\\371\\373\\002\\377\\006\\006\\007\\377\\376\\375\\002\\006\\005\\004\\374\\002\\000\\373\\004\\002\\002\\374\\371\\372\\371\\375\\001\\004\\000\\006\\376\\377\\002\\000\\372\\001\\001\\375\\007\\376\\005\\001\\373\\003\\374\\005\\003\\007\\005\\372\\004\\006\\375\\005\\003\\001\\373\\376\\374\\002\\376\\377\\376\\000\\006\\001\\375\\376\\377\\374\\000\\005\\002\\005\\006\\371\\375\\005\\375\\376\\374\\004\\001\\003\\001\\372\\005\\007\\371\\005\\000\\372\\001\\001\\371\\007\\374\\372\\373\\373\\372\\376\\004\\000\\002\\375\\376\\000\\004\\003\\003\\375\\003\\001\\376\\006\\001\\000\\372\\374\\376\\373\\002\\002\\004\\372\\377\\374\\005\\000\\001\\005\\005\\374\\007\\003\\377\\377\\000\\007\\002\\377\\377\\377\\374\\001\\001\\376\\000\\377\\373\\001\\004\\
376 [... remainder of the old raw_data blob (escaped int8 weight bytes from the stale notebook output) elided ...]\"\n",
+          "      raw_data: \"[... new raw_data blob produced by re-executing the notebook, elided ...]
6\\372\\373\\374\\004\\374\\002\\376\\003\\003\\372\\000\\375\\002\\007\\005\\375\\373\\000\\373\\373\\002\\372\\376\\005\\000\\004\\006\\375\\374\\006\\372\\377\\006\\000\\005\\004\\002\\375\\376\\000\\005\\003\\374\\375\\001\\372\\373\\005\\002\\376\\374\\007\\000\\003\\002\\005\\006\\374\\374\\006\\371\\375\\002\\005\\005\\372\\003\\372\\001\\000\\376\\377\\372\\372\\000\\004\\002\\002\\373\\376\\374\\373\\003\\373\\376\\002\\007\\007\\004\\003\\376\\373\\002\\003\\372\\001\\001\\001\\375\\376\\377\\372\\004\\002\\000\\003\\371\\002\\003\\377\\375\\001\\372\\372\\003\\005\\376\\007\\374\\374\\000\\374\\376\\004\\374\\004\\373\\004\\375\\000\\376\\001\\377\\004\\007\\373\\003\\371\\001\\375\\007\\002\\000\\001\\003\\006\\004\"\n", " }\n", " type: TENSOR\n", " }\n", @@ -528,7 +541,7 @@ " name: \"value\"\n", " t {\n", " data_type: 1\n", - " raw_data: \"\\263-\\341<\"\n", + " raw_data: \"\\2556\\341<\"\n", " }\n", " type: TENSOR\n", " }\n", @@ -542,7 +555,7 @@ " t {\n", " dims: 128\n", " data_type: 6\n", - " raw_data: \"\\271\\377\\377\\377\\032\\003\\000\\0009\\001\\000\\000\\302\\002\\000\\000;\\375\\377\\377\\031\\000\\000\\000\\024\\003\\000\\000d\\003\\000\\000\\327\\374\\377\\377\\363\\377\\377\\377u\\003\\000\\000\\374\\000\\000\\000t\\000\\000\\000\\321\\002\\000\\000\\236\\377\\377\\377\\241\\377\\377\\377\\237\\375\\377\\377\\010\\000\\000\\000\\350\\002\\000\\000}\\376\\377\\377\\267\\377\\377\\377\\374\\000\\000\\000\\355\\001\\000\\000N\\375\\377\\377\\\\\\002\\000\\000\\346\\002\\000\\000\\317\\000\\000\\000\\207\\001\\000\\000?\\000\\000\\000\\302\\002\\000\\000Y\\377\\377\\377\\326\\376\\377\\377\\\\\\003\\000\\000\\374\\376\\377\\377\\334\\000\\000\\000\\200\\001\\000\\000\\362\\377\\377\\377+\\000\\000\\000\\304\\375\\377\\377u\\000\\000\\000\\340\\000\\000\\000\\275\\001\\000\\000\\324\\377\\377\\377\\332\\000\\000\\000\\026\\001\\000\\000\\333\\001\\000\\000\\371\\375\\377\\377\\363\\000\\000\\000|\\002\\000\\000\\335\\376\\377\\377\\226\\375\\377\\377\\335\\002\\000\\0002\\001\\000\\000F\\377\\377\\377\\006\\003\\000\\000\\310\\375\\377\\377\\344\\377\\377\\377\\177\\376\\377\\377>\\001\\000\\000\\033\\002\\000\\000I\\003\\000\\000\\006\\376\\377\\377\\315\\375\\377\\377\\033\\003\\000\\000\\236\\000\\000\\000@\\376\\377\\377\\031\\002\\000\\000\\321\\002\\000\\000;\\000\\000\\000\\035\\377\\377\\377\\354\\377\\377\\377Z\\001\\000\\000N\\375\\377\\377I\\001\\000\\000\\030\\001\\000\\000w\\377\\377\\377\\303\\002\\000\\000\\022\\000\\000\\000\\377\\001\\000\\000!\\000\\000\\000\\035\\001\\000\\000\\003\\375\\377\\377^\\377\\377\\377\\336\\374\\377\\377p\\377\\377\\377\\351\\002\\000\\000X\\376\\377\\377\\247\\000\\000\\000H\\376\\377\\377}\\000\\000\\000\\225\\374\\377\\3776\\001\\000\\000\\301\\001\\000\\000\\210\\001\\000\\000\\374\\376\\377\\377\\307\\377\\377\\377\\320\\374\\377\\377\\267\\377\\377\\377F\\375\\377\\377\\352\\377\\377\\377=\\377\\377\\3770\\376\\377\\377#\\000\\000\\000\\313\\376\\377\\377\\334\\000\\000\\000\\261\\001\\000\\000\\363\\001\\000\\000\\037\\001\\000\\000\\220\\377\\377\\377\\202\\000\\000\\000d\\377\\377\\377\\013\\002\\000\\000\\266\\002\\000\\000\\347\\374\\377\\377+\\001\\000\\000\\301\\376\\377\\377\\341\\377\\377\\377O\\003\\000\\000\\037\\375\\377\\377\\244\\375\\377\\377\\352\\000\\000\\000\\302\\001\\000\\000I\\002\\000\\000~\\377\\377\\377*\\376\\377\\377\\333\\000\\000\\000\\214\\000\\000\\000\\014\\002\\000\\000\"\n", + " raw_data: 
\"\\016\\003\\000\\000\\240\\375\\377\\377\\344\\002\\000\\000\\341\\002\\000\\000\\207\\000\\000\\000C\\377\\377\\377\\255\\375\\377\\377,\\376\\377\\377\\\"\\001\\000\\000\\237\\001\\000\\000\\'\\003\\000\\000\\220\\377\\377\\377{\\003\\000\\000\\252\\002\\000\\000A\\003\\000\\000\\233\\002\\000\\000\\375\\377\\377\\377\\302\\001\\000\\000\\365\\374\\377\\377\\025\\003\\000\\000w\\003\\000\\000\\231\\375\\377\\377\\030\\377\\377\\377K\\000\\000\\000h\\002\\000\\0002\\001\\000\\000X\\003\\000\\000\\241\\001\\000\\000W\\000\\000\\000\\010\\002\\000\\000R\\002\\000\\000v\\003\\000\\000\\353\\001\\000\\000J\\001\\000\\000\\312\\377\\377\\377\\007\\002\\000\\000\\345\\376\\377\\377\\316\\001\\000\\000\\352\\000\\000\\000\\357\\375\\377\\377\\004\\001\\000\\000\\353\\002\\000\\000\\342\\376\\377\\377#\\003\\000\\000\\252\\001\\000\\000\\354\\377\\377\\377Y\\003\\000\\000x\\000\\000\\000\\251\\377\\377\\377f\\000\\000\\000}\\003\\000\\000\\317\\374\\377\\377\\300\\376\\377\\377\\230\\000\\000\\000c\\003\\000\\000\\204\\377\\377\\377n\\376\\377\\377(\\375\\377\\377\\314\\001\\000\\000\\304\\000\\000\\000\\357\\374\\377\\377\\241\\376\\377\\377\\217\\375\\377\\377\\r\\001\\000\\000;\\001\\000\\000\\240\\377\\377\\377Q\\377\\377\\377U\\375\\377\\377\\'\\377\\377\\377h\\002\\000\\000f\\002\\000\\000\\307\\002\\000\\000\\364\\001\\000\\000\\303\\000\\000\\000W\\000\\000\\000\\001\\375\\377\\377!\\002\\000\\000\\210\\000\\000\\000:\\377\\377\\377\\242\\000\\000\\000o\\377\\377\\377\\327\\001\\000\\000\\263\\377\\377\\377X\\003\\000\\000\\303\\000\\000\\0003\\000\\000\\000\\337\\375\\377\\377=\\375\\377\\377{\\000\\000\\000\\336\\375\\377\\377I\\375\\377\\377\\036\\377\\377\\377\\016\\002\\000\\000\\017\\003\\000\\000\\240\\374\\377\\377f\\001\\000\\000\\003\\375\\377\\377\\020\\375\\377\\377\\224\\000\\000\\000S\\375\\377\\377\\266\\001\\000\\000\\337\\002\\000\\000\\356\\375\\377\\377\\027\\000\\000\\000\\340\\000\\000\\000p\\377\\377\\377\\371\\000\\000\\000\\253\\000\\000\\000\\322\\001\\000\\0008\\001\\000\\000\\346\\377\\377\\377\\271\\374\\377\\377\\003\\376\\377\\377\\032\\001\\000\\000\\207\\374\\377\\377\\265\\374\\377\\377;\\003\\000\\000\\353\\374\\377\\377\\275\\002\\000\\000:\\002\\000\\000}\\377\\377\\377\\000\\376\\377\\377W\\376\\377\\377\\235\\000\\000\\000\\333\\376\\377\\377x\\375\\377\\377+\\376\\377\\377\\025\\003\\000\\000\"\n", " }\n", " type: TENSOR\n", " }\n", @@ -600,9 +613,9 @@ " name: \"/linear/export_handler/DequantizeLinear\"\n", " op_type: \"DequantizeLinear\"\n", " }\n", - " name: \"torch_jit\"\n", + " name: \"main_graph\"\n", " input {\n", - " name: \"inp.1\"\n", + " name: \"out.1\"\n", " type {\n", " tensor_type {\n", " elem_type: 1\n", @@ -652,7 +665,7 @@ "}" ] }, - "execution_count": 8, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -691,7 +704,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 22, "metadata": { "collapsed": false, "pycharm": { @@ -724,10 +737,10 @@ " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 9, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -792,14 +805,22 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 23, "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-02-02 09:51:08.764823876 [W:onnxruntime:, graph.cc:1283 Graph] Initializer linear.bias appears in graph inputs and will not be treated as constant 
value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.\n" + ] + } + ], "source": [ "import onnxruntime as ort\n", "\n", @@ -862,7 +883,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 24, "metadata": { "collapsed": false, "pycharm": { @@ -910,7 +931,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 25, "metadata": { "collapsed": false, "pycharm": { @@ -922,7 +943,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "c:\\users\\alessand\\documents\\brevitas\\src\\brevitas\\export\\onnx\\standard\\manager.py:23: UserWarning: ONNX opset version set to 13, override with opset_version=\n", + "/home/giuseppe/Documents/git/brevitas/src/brevitas/export/onnx/standard/manager.py:26: UserWarning: ONNX opset version set to 13, override with opset_version=\n", " warnings.warn(f\"ONNX opset version set to {DEFAULT_OPSET}, override with {ka}=\")\n" ] } @@ -997,7 +1018,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.15" + "version": "3.11.5" }, "orig_nbformat": 4, "vscode": { diff --git a/notebooks/quantized_recurrent.ipynb b/notebooks/quantized_recurrent.ipynb index 766e82745..5bb95a465 100644 --- a/notebooks/quantized_recurrent.ipynb +++ b/notebooks/quantized_recurrent.ipynb @@ -38,12 +38,14 @@ " bias: bool = True,\n", " batch_first: bool = False,\n", " bidirectional: bool = False,\n", - " weight_quant = Int8WeightPerTensorFloat,\n", - " bias_quant = Int32Bias,\n", - " io_quant = Int8ActPerTensorFloat,\n", - " gate_acc_quant = Int8ActPerTensorFloat,\n", - " shared_input_hidden_weights = False,\n", + " weight_quant=Int8WeightPerTensorFloat,\n", + " bias_quant=Int32Bias,\n", + " io_quant=Int8ActPerTensorFloat,\n", + " gate_acc_quant=Int8ActPerTensorFloat,\n", + " shared_input_hidden_weights=False,\n", " return_quant_tensor: bool = False,\n", + " dtype: Optional[torch.dtype] = None,\n", + " device: Optional[torch.device] = None,\n", " **kwargs):\n", " super(QuantRNN, self).__init__(\n", " layer_impl=_QuantRNNLayer,\n", @@ -60,6 +62,8 @@ " gate_acc_quant=gate_acc_quant,\n", " shared_input_hidden_weights=shared_input_hidden_weights,\n", " return_quant_tensor=return_quant_tensor,\n", + " dtype=dtype,\n", + " device=device,\n", " **kwargs)\n", "\n", "```" @@ -107,7 +111,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "c:\\users\\alessand\\documents\\brevitas\\src\\brevitas\\nn\\mixin\\base.py:112: UserWarning: Keyword arguments are being passed but they not being used.\n", + "/home/giuseppe/Documents/git/brevitas/src/brevitas/nn/mixin/base.py:78: UserWarning: Keyword arguments are being passed but they not being used.\n", " warn('Keyword arguments are being passed but they not being used.')\n" ] } @@ -278,46 +282,46 @@ "Input-hidden weight bit-width: 4.0\n", "Hidden-hidden weight bit-width: 4.0\n", "I/O quant bit-width: 6.0\n", - "Input-hidden weight scale: tensor([[0.0316],\n", - " [0.0317],\n", - " [0.0319],\n", - " [0.0318],\n", - " [0.0314],\n", + "Input-hidden weight scale: tensor([[0.0297],\n", + " [0.0311],\n", " [0.0298],\n", + " [0.0295],\n", + " [0.0316],\n", + " [0.0311],\n", + " [0.0318],\n", + " [0.0309],\n", " [0.0317],\n", - " [0.0285],\n", - " [0.0306],\n", - " [0.0312],\n", + " [0.0309],\n", + " [0.0316],\n", + " [0.0319],\n", + " 
[0.0319],\n", " [0.0318],\n", " [0.0315],\n", - " [0.0298],\n", - " [0.0314],\n", - " [0.0293],\n", " [0.0310],\n", - " [0.0306],\n", - " [0.0310],\n", - " [0.0309],\n", - " [0.0317]], grad_fn=)\n", - "Hidden-hidden weight scale: tensor([[0.0316],\n", - " [0.0317],\n", + " [0.0319],\n", " [0.0319],\n", " [0.0318],\n", - " [0.0314],\n", + " [0.0312]], grad_fn=)\n", + "Hidden-hidden weight scale: tensor([[0.0297],\n", + " [0.0311],\n", " [0.0298],\n", + " [0.0295],\n", + " [0.0316],\n", + " [0.0311],\n", + " [0.0318],\n", + " [0.0309],\n", " [0.0317],\n", - " [0.0285],\n", - " [0.0306],\n", - " [0.0312],\n", + " [0.0309],\n", + " [0.0316],\n", + " [0.0319],\n", + " [0.0319],\n", " [0.0318],\n", " [0.0315],\n", - " [0.0298],\n", - " [0.0314],\n", - " [0.0293],\n", - " [0.0310],\n", - " [0.0306],\n", " [0.0310],\n", - " [0.0309],\n", - " [0.0317]], grad_fn=)\n" + " [0.0319],\n", + " [0.0319],\n", + " [0.0318],\n", + " [0.0312]], grad_fn=)\n" ] } ], @@ -387,52 +391,54 @@ "name": "stderr", "output_type": "stream", "text": [ - "c:\\users\\alessand\\documents\\brevitas\\src\\brevitas\\nn\\mixin\\base.py:343: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at C:\\cb\\pytorch_1000000000000\\work\\torch\\csrc\\utils\\python_arg_parser.cpp:354.)\n", + "/home/giuseppe/Documents/git/brevitas/src/brevitas/quant_tensor/__init__.py:84: UserWarning: Empty QuantTensor are deprecated and will be removed in a future version\n", + " warnings.warn(\n", + "/home/giuseppe/Documents/git/brevitas/src/brevitas/nn/mixin/base.py:320: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. 
(Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n", " return torch.cat(outputs, dim=seq_dim)\n" ] }, { "data": { "text/plain": [ - "(QuantTensor(value=tensor([[[-0.4458, -0.1651, -0.7045, -0.5889, -0.2532, -0.0330, -0.1651,\n", - " 0.1706, 0.1376, 0.4348, 0.5834, -0.3577, -0.2807, 0.1046,\n", - " 0.2532, 0.2807, 0.2532, -0.4293, 0.1376, -0.1486],\n", - " [-0.1569, 0.3530, -0.6995, -0.0458, -0.5295, -0.3007, -0.7257,\n", - " 0.2877, -0.1308, 0.6603, 0.0196, -0.8237, 0.0065, -0.4380,\n", - " -0.2615, 0.3138, -0.0850, 0.0065, 0.0458, -0.1961],\n", - " [ 0.1929, -0.5981, -0.2508, -0.2251, -0.5917, 0.2251, 0.0257,\n", - " 0.2508, -0.3023, 0.2830, 0.3344, -0.4309, -0.0836, 0.2701,\n", - " 0.3666, -0.1351, 0.1736, -0.0257, 0.1286, -0.6174],\n", - " [ 0.4682, -0.1804, 0.2780, 0.4974, 0.4389, -0.0585, -0.6242,\n", - " -0.0098, 0.2341, 0.3511, -0.2926, -0.4925, 0.1414, -0.4633,\n", - " -0.0683, 0.2633, 0.3804, 0.3024, 0.1951, 0.1707],\n", - " [-0.0852, 0.0965, -0.4656, -0.3180, -0.3464, -0.2782, -0.1931,\n", - " -0.6360, -0.3180, -0.3293, 0.7211, 0.4316, 0.4145, -0.3066,\n", - " -0.5224, -0.3066, -0.5849, -0.7211, 0.3293, 0.1420]],\n", + "(QuantTensor(value=tensor([[[-0.0062, -0.2872, 0.7931, 0.4309, 0.5495, -0.4558, 0.2373,\n", + " 0.6807, 0.4621, 0.6120, -0.1124, 0.3872, 0.3060, 0.7681,\n", + " -0.3684, 0.0437, -0.7369, -0.3247, 0.7743, 0.3372],\n", + " [ 0.5450, 0.2962, -0.3969, 0.3555, -0.5628, 0.2429, -0.4976,\n", + " 0.1777, -0.1244, 0.0296, -0.2607, 0.0948, 0.5036, -0.3673,\n", + " 0.5213, -0.2962, 0.7524, 0.0770, -0.0948, -0.0948],\n", + " [ 0.2691, -0.6624, -0.5434, 0.4968, -0.6624, 0.0983, 0.1345,\n", + " 0.1242, -0.0517, -0.3726, 0.3053, 0.1604, 0.3208, 0.0983,\n", + " 0.3105, 0.4243, 0.2794, 0.1604, 0.1035, -0.0724],\n", + " [ 0.1284, -0.3337, -0.5263, -0.0449, -0.5263, 0.3081, -0.1733,\n", + " 0.5648, 0.4942, -0.1412, 0.1733, 0.3337, 0.6225, 0.3401,\n", + " 0.5070, -0.1412, 0.0642, -0.3722, 0.2888, 0.1155],\n", + " [ 0.0579, -0.0058, -0.4054, -0.1564, -0.5560, -0.3301, 0.3533,\n", + " 0.0058, -0.1622, -0.3765, 0.1216, 0.0695, -0.4054, 0.0927,\n", + " 0.6139, -0.1390, 0.7066, 0.1274, 0.1622, -0.2896]],\n", " \n", - " [[ 0.5669, 0.2367, -0.3027, -0.3137, -0.3632, -0.1651, -0.5999,\n", - " 0.2036, 0.4293, 0.2201, -0.2862, -0.3908, -0.2091, -0.2532,\n", - " -0.2532, -0.5834, -0.2697, 0.0055, 0.2532, 0.1761],\n", - " [ 0.1242, 0.4184, -0.6472, -0.0196, -0.4707, -0.5034, -0.8368,\n", - " 0.3530, 0.1504, 0.0458, -0.0654, -0.7714, -0.1961, -0.4903,\n", - " -0.6015, -0.3596, -0.2484, -0.4380, -0.0458, 0.2942],\n", - " [ 0.3409, 0.8168, -0.7396, 0.2958, 0.2508, -0.1286, -0.1286,\n", - " 0.7782, -0.1994, 0.7846, -0.3087, -0.3666, 0.1029, 0.1479,\n", - " -0.3216, -0.1479, -0.2315, 0.4566, 0.5209, -0.3344],\n", - " [-0.0878, 0.0390, -0.1707, -0.1365, -0.2243, -0.2390, -0.3706,\n", - " 0.1609, -0.5511, -0.4096, 0.5121, -0.5901, 0.2633, -0.3609,\n", - " -0.5511, 0.3755, -0.4925, -0.0293, -0.0780, -0.2829],\n", - " [ 0.0965, -0.1987, 0.0057, 0.1306, 0.3861, 0.2839, -0.3861,\n", - " 0.5962, -0.1987, 0.3180, -0.1647, -0.3066, -0.0227, 0.4372,\n", - " 0.0852, 0.3748, 0.0852, -0.0057, -0.1703, -0.0738]]],\n", - " grad_fn=), scale=tensor(0.0058, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)),\n", - " QuantTensor(value=tensor([[[-0.0852, 0.0965, -0.4656, -0.3180, -0.3464, -0.2782, -0.1931,\n", - " -0.6360, -0.3180, -0.3293, 0.7211, 0.4316, 0.4145, -0.3066,\n", - " -0.5224, -0.3066, 
-0.5849, -0.7211, 0.3293, 0.1420],\n", - " [ 0.0965, -0.1987, 0.0057, 0.1306, 0.3861, 0.2839, -0.3861,\n", - " 0.5962, -0.1987, 0.3180, -0.1647, -0.3066, -0.0227, 0.4372,\n", - " 0.0852, 0.3748, 0.0852, -0.0057, -0.1703, -0.0738]]],\n", - " grad_fn=), scale=tensor(0.0057, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)))" + " [[ 0.1374, 0.5745, 0.0624, -0.2373, 0.3060, 0.3310, -0.5183,\n", + " 0.1186, 0.1124, 0.2997, 0.0375, 0.6369, -0.5308, 0.6307,\n", + " -0.5683, 0.7556, 0.2997, -0.4933, 0.3934, -0.4871],\n", + " [ 0.1066, -0.1244, -0.1718, 0.4266, 0.5569, 0.0178, 0.1185,\n", + " -0.3910, 0.2133, 0.0178, -0.1066, -0.2903, 0.1837, -0.2547,\n", + " -0.2903, 0.0770, 0.3495, 0.2547, 0.2311, -0.6161],\n", + " [-0.0880, -0.1966, 0.3001, -0.0569, 0.4140, -0.1552, -0.1345,\n", + " 0.4554, 0.5175, 0.1242, -0.2898, 0.1966, -0.0414, 0.3985,\n", + " -0.1708, -0.0621, -0.1708, 0.0828, 0.2225, 0.0517],\n", + " [ 0.2118, 0.5648, -0.2824, -0.0449, 0.5840, 0.3209, -0.5648,\n", + " 0.3530, 0.4043, -0.4942, -0.3786, 0.0257, 0.5327, -0.1990,\n", + " -0.1348, -0.8215, 0.3016, 0.5327, 0.5648, -0.1155],\n", + " [-0.0290, -0.1738, 0.0695, 0.3765, 0.1738, 0.0579, -0.4054,\n", + " -0.2664, 0.4923, 0.2143, -0.4170, 0.4112, 0.5502, 0.7066,\n", + " -0.6024, 0.7356, 0.0348, 0.1043, -0.1911, -0.4518]]],\n", + " grad_fn=), scale=tensor(0.0059, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)),\n", + " QuantTensor(value=tensor([[[ 0.0579, -0.0058, -0.4054, -0.1564, -0.5560, -0.3301, 0.3533,\n", + " 0.0058, -0.1622, -0.3765, 0.1216, 0.0695, -0.4054, 0.0927,\n", + " 0.6139, -0.1390, 0.7066, 0.1274, 0.1622, -0.2896],\n", + " [-0.0290, -0.1738, 0.0695, 0.3765, 0.1738, 0.0579, -0.4054,\n", + " -0.2664, 0.4923, 0.2143, -0.4170, 0.4112, 0.5502, 0.7066,\n", + " -0.6024, 0.7356, 0.0348, 0.1043, -0.1911, -0.4518]]],\n", + " grad_fn=), scale=tensor(0.0058, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)))" ] }, "execution_count": 10, @@ -461,48 +467,56 @@ "execution_count": 11, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/_tensor.py:1362: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. 
(Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/c10/core/TensorImpl.h:1900.)\n", + " return super().rename(names)\n" + ] + }, { "data": { "text/plain": [ - "(QuantTensor(value=tensor([[[ 0.1760, 0.2670, -0.1214, -0.3702, 0.3884, 0.4127, 0.0243,\n", - " 0.0425, -0.2246, -0.0910, -0.2670, 0.4734, 0.0971, -0.3824,\n", - " 0.1396, 0.6858, 0.0061, 0.3702, 0.1275, 0.5037],\n", - " [ 0.2831, 0.0566, -0.2831, -0.2661, -0.0793, 0.3511, -0.4926,\n", - " 0.0510, -0.6455, 0.7191, -0.1812, -0.6172, 0.1529, 0.4077,\n", - " -0.7078, -0.0453, -0.0963, 0.4926, -0.4983, -0.4077],\n", - " [ 0.0000, -0.3977, 0.0947, 0.1894, -0.3725, -0.2589, -0.3914,\n", - " 0.3409, -0.0063, 0.2652, -0.5177, -0.4230, -0.0821, -0.0631,\n", - " 0.0505, -0.0189, 0.0253, -0.1578, -0.4988, 0.5556],\n", - " [ 0.4809, 0.8144, -0.6925, 0.4360, 0.0256, -0.4360, -0.5130,\n", - " 0.2501, -0.1347, 0.7631, -0.5386, -0.2437, 0.4296, -0.1988,\n", - " -0.7246, -0.1154, -0.2437, 0.3655, 0.0641, 0.3142],\n", - " [ 0.0706, -0.0192, -0.7185, -0.8211, -0.5709, 0.1155, 0.4683,\n", - " 0.3400, -0.3015, 0.3528, 0.3143, -0.1155, -0.3143, -0.0257,\n", - " 0.1411, -0.2309, 0.5132, 0.3721, 0.5196, -0.5453]],\n", + "(QuantTensor(value=tensor([[[ 0.2111, 0.1267, 0.0060, 0.6153, -0.7721, -0.3740, -0.5188,\n", + " 0.6273, 0.4162, 0.2051, 0.2292, 0.7239, 0.6032, 0.2533,\n", + " 0.5067, 0.6635, 0.1206, -0.5730, 0.0483, 0.3318],\n", + " [ 0.5742, 0.0194, -0.3807, -0.0710, -0.6000, 0.1807, 0.1355,\n", + " 0.4129, 0.3807, 0.3936, -0.0903, 0.1549, 0.1032, 0.0645,\n", + " 0.4775, -0.0645, 0.1161, -0.0065, 0.0194, -0.1097],\n", + " [ 0.0453, -0.4533, 0.1036, -0.0194, -0.2979, 0.3432, 0.0777,\n", + " 0.6346, -0.0842, 0.3302, 0.4727, 0.4856, -0.4144, 0.7382,\n", + " -0.0453, 0.5439, 0.2266, -0.4792, 0.4403, -0.1036],\n", + " [ 0.3198, 0.2741, -0.6395, 0.0971, -0.6052, -0.5196, 0.1770,\n", + " -0.5025, -0.1256, 0.2056, 0.2684, -0.6395, -0.0285, -0.7309,\n", + " 0.7194, -0.7194, 0.1542, -0.3426, -0.6509, 0.0343],\n", + " [ 0.0000, -0.4004, 0.3151, -0.0263, -0.5842, -0.1641, -0.3939,\n", + " 0.0263, -0.2429, 0.6499, -0.5186, 0.1247, -0.2101, 0.8337,\n", + " -0.1444, 0.6762, -0.1641, -0.5317, -0.1707, -0.0197]],\n", " \n", - " [[ 0.4066, -0.7768, 0.6008, 0.0546, 0.0182, 0.1821, 0.0971,\n", - " -0.3763, 0.3520, -0.5037, -0.0061, 0.2246, -0.0486, 0.2124,\n", - " 0.3641, -0.6433, 0.4248, 0.0789, 0.1275, -0.1214],\n", - " [ 0.2321, 0.1982, -0.1302, 0.1529, -0.0736, -0.3567, -0.4360,\n", - " -0.0283, 0.4869, 0.5379, -0.6964, -0.0340, -0.2944, -0.1529,\n", - " -0.2152, -0.4643, 0.3454, 0.3284, -0.3341, 0.5945],\n", - " [-0.2020, 0.0379, -0.8081, -0.7260, -0.0821, 0.0631, 0.4988,\n", - " 0.0694, 0.0253, 0.5430, 0.8018, 0.2273, -0.3472, -0.0505,\n", - " 0.4924, -0.4735, 0.5745, -0.5619, 0.6313, -0.1768],\n", - " [ 0.2501, -0.4360, 0.6541, 0.0385, 0.5835, -0.3078, -0.0449,\n", - " 0.3270, 0.7951, -0.3591, -0.4809, -0.2757, -0.3591, -0.7567,\n", - " 0.5194, 0.2757, 0.7438, 0.7695, 0.5451, 0.4296],\n", - " [ 0.2630, -0.4747, 0.1347, -0.0641, -0.2245, -0.3336, -0.4490,\n", - " -0.4619, -0.1796, -0.5517, 0.3913, 0.0257, -0.2053, -0.2823,\n", - " -0.6992, -0.6607, 0.1989, -0.6928, -0.5581, 0.5966]]],\n", + " [[ 0.2111, -0.2111, -0.3197, -0.0241, -0.5067, -0.0241, -0.2895,\n", + " 0.1749, -0.4283, 0.0000, -0.3680, 0.5308, -0.1267, 0.5248,\n", + " 0.1206, 0.2654, 0.6394, -0.1327, 0.2292, -0.3800],\n", + " [ 0.6775, -0.3355, -0.1807, 0.2774, -0.8259, -0.2000, -0.0065,\n", + " 0.5678, 0.4000, 0.2258, 0.4387, 0.2710, 0.5355, 0.1290,\n", + " 0.6710, -0.0645, 
-0.2710, -0.3613, 0.6388, 0.5226],\n", + " [-0.0065, -0.0777, -0.6475, -0.1684, -0.3820, 0.3885, 0.0065,\n", + " 0.1943, -0.3238, -0.2525, -0.1230, -0.0453, -0.0777, 0.3432,\n", + " 0.4921, -0.1101, 0.8224, 0.2396, 0.1554, -0.3885],\n", + " [-0.0514, -0.4111, -0.4625, -0.1713, -0.3369, 0.2512, -0.2969,\n", + " -0.4111, -0.2341, 0.3597, -0.1998, 0.0000, 0.2741, 0.7137,\n", + " -0.1256, 0.1370, -0.0742, -0.5938, -0.5424, -0.4168],\n", + " [ 0.3479, 0.5974, -0.3939, 0.1444, -0.6762, 0.1969, -0.6499,\n", + " 0.4136, 0.5383, -0.3085, 0.4070, 0.4070, 0.6630, -0.0263,\n", + " 0.2823, -0.1510, 0.1313, -0.5186, 0.4464, -0.0066]]],\n", " grad_fn=), scale=tensor(0.0062, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)),\n", - " QuantTensor(value=tensor([[[ 0.0706, -0.0192, -0.7185, -0.8211, -0.5709, 0.1155, 0.4683,\n", - " 0.3400, -0.3015, 0.3528, 0.3143, -0.1155, -0.3143, -0.0257,\n", - " 0.1411, -0.2309, 0.5132, 0.3721, 0.5196, -0.5453],\n", - " [ 0.2630, -0.4747, 0.1347, -0.0641, -0.2245, -0.3336, -0.4490,\n", - " -0.4619, -0.1796, -0.5517, 0.3913, 0.0257, -0.2053, -0.2823,\n", - " -0.6992, -0.6607, 0.1989, -0.6928, -0.5581, 0.5966]]],\n", - " grad_fn=), scale=tensor(0.0064, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)))" + " QuantTensor(value=tensor([[[ 0.0000, -0.4004, 0.3151, -0.0263, -0.5842, -0.1641, -0.3939,\n", + " 0.0263, -0.2429, 0.6499, -0.5186, 0.1247, -0.2101, 0.8337,\n", + " -0.1444, 0.6762, -0.1641, -0.5317, -0.1707, -0.0197],\n", + " [ 0.3479, 0.5974, -0.3939, 0.1444, -0.6762, 0.1969, -0.6499,\n", + " 0.4136, 0.5383, -0.3085, 0.4070, 0.4070, 0.6630, -0.0263,\n", + " 0.2823, -0.1510, 0.1313, -0.5186, 0.4464, -0.0066]]],\n", + " grad_fn=), scale=tensor(0.0066, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)))" ] }, "execution_count": 11, @@ -533,45 +547,45 @@ { "data": { "text/plain": [ - "(QuantTensor(value=tensor([[[-0.1984, 0.2499, -0.1102, 0.2499, -0.0955, -0.4630, -0.8672,\n", - " 0.1911, -0.4851, 0.8085, 0.6982, -0.5806, 0.0000, -0.4189,\n", - " -0.7423, -0.4851, -0.9260, -0.0147, 0.0514, -0.1984],\n", - " [-0.2167, 0.5092, -0.3846, 0.0650, 0.6717, -0.2492, -0.0867,\n", - " 0.3142, -0.3900, 0.3521, 0.4767, -0.1137, 0.6879, 0.1733,\n", - " -0.0596, 0.4279, -0.5471, -0.2762, 0.5904, -0.3737],\n", - " [-0.1335, -0.0140, -0.2810, -0.5339, -0.5339, 0.0562, 0.7236,\n", - " -0.1264, -0.0211, -0.3021, -0.1124, 0.4777, 0.3793, 0.2388,\n", - " -0.0702, 0.4847, -0.4988, 0.7236, 0.5901, -0.4847],\n", - " [ 0.3340, -0.5225, -0.1242, 0.1499, 0.3083, -0.1756, -0.1713,\n", - " 0.0000, 0.3512, -0.3041, 0.3126, -0.5482, 0.4882, 0.1028,\n", - " -0.4796, 0.1028, -0.2527, -0.3640, 0.1713, 0.0471],\n", - " [-0.4438, -0.2686, -0.3095, -0.2978, -0.0993, 0.0584, 0.4846,\n", - " -0.0526, 0.3737, -0.4496, 0.1109, 0.7416, -0.0526, 0.3445,\n", - " 0.4963, 0.2803, 0.1927, 0.0000, 0.6131, 0.1109]],\n", + "(QuantTensor(value=tensor([[[-0.3777, -0.2074, 0.7184, 0.9110, 0.0148, -0.1926, -0.7110,\n", + " 0.1926, -0.4222, -0.9480, 0.2592, 0.2222, -0.2370, -0.5407,\n", + " 0.5851, -0.2370, 0.3555, 0.1703, 0.4444, -0.2222],\n", + " [ 0.4814, -0.7355, -0.1605, 0.3878, -0.5282, 0.2073, 0.0000,\n", + " 0.3677, 0.1805, -0.1204, -0.4614, 0.2474, 0.7021, 0.0401,\n", + " 0.4346, 0.4480, -0.3143, 0.0401, 0.6887, 0.6753],\n", + " [ 0.5038, -0.3650, -0.6936, 0.0146, -0.9345, 0.0000, 0.1679,\n", + " -0.3066, 0.1825, 0.4089, 0.0949, -0.2555, 0.3870, -0.2482,\n", + 
" 0.5914, -0.0803, 0.1314, -0.4235, -0.3797, 0.1168],\n", + " [ 0.1795, 0.1795, 0.0449, 0.0449, 0.2308, 0.0898, -0.1282,\n", + " 0.5579, 0.1731, -0.1795, 0.1603, 0.3142, 0.1090, 0.5835,\n", + " -0.1475, 0.0449, 0.1795, -0.0256, 0.8143, -0.2437],\n", + " [-0.0066, 0.4804, 0.0066, -0.1184, 0.6843, -0.0197, 0.1448,\n", + " 0.1842, 0.6383, -0.1908, -0.0066, -0.1053, -0.1316, 0.0461,\n", + " -0.0066, -0.2764, 0.3751, 0.3619, 0.5001, -0.1316]],\n", " \n", - " [[ 0.1102, -0.8085, 0.5806, -0.0661, 0.3013, 0.2646, 0.2499,\n", - " -0.6321, 0.4557, 0.4777, 0.6321, 0.0294, -0.2646, -0.9407,\n", - " 0.7350, -0.6027, 0.6174, -0.4116, 0.6835, 0.0514],\n", - " [ 0.1787, 0.0271, 0.1354, -0.3033, 0.6229, -0.3250, -0.3846,\n", - " 0.0812, 0.5633, 0.6879, -0.0325, -0.2383, -0.3521, -0.5850,\n", - " 0.3033, -0.3900, 0.6771, 0.3196, 0.5633, 0.2383],\n", - " [-0.1264, 0.5901, -0.3934, 0.3231, 0.0492, -0.5128, -0.8149,\n", - " 0.1124, -0.7517, 0.8711, 0.4004, -0.8992, 0.0702, -0.2178,\n", - " -0.8851, -0.5760, -0.1054, -0.0702, -0.3512, -0.5198],\n", - " [ 0.2612, 0.2570, 0.1542, -0.1071, -0.0300, 0.0257, -0.3854,\n", - " -0.0685, -0.2570, 0.0728, -0.4240, -0.3083, 0.1627, -0.3383,\n", - " -0.0428, 0.0300, -0.1199, 0.3683, 0.3298, -0.3340],\n", - " [ 0.4204, -0.2452, -0.0934, 0.2336, 0.1285, -0.1285, 0.2044,\n", - " -0.0701, 0.0058, 0.3971, 0.0175, -0.3270, 0.2803, 0.1810,\n", - " -0.4963, -0.5547, 0.0467, 0.0175, 0.1927, -0.2452]]],\n", - " grad_fn=), scale=tensor(0.0060, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)),\n", - " QuantTensor(value=tensor([[[-0.4438, -0.2686, -0.3095, -0.2978, -0.0993, 0.0584, 0.4846,\n", - " -0.0526, 0.3737, -0.4496, 0.1109, 0.7416, -0.0526, 0.3445,\n", - " 0.4963, 0.2803, 0.1927, 0.0000, 0.6131, 0.1109],\n", - " [ 0.4204, -0.2452, -0.0934, 0.2336, 0.1285, -0.1285, 0.2044,\n", - " -0.0701, 0.0058, 0.3971, 0.0175, -0.3270, 0.2803, 0.1810,\n", - " -0.4963, -0.5547, 0.0467, 0.0175, 0.1927, -0.2452]]],\n", - " grad_fn=), scale=tensor(0.0058, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)))" + " [[ 0.5110, -0.3555, 0.6443, -0.8221, 0.4888, -0.2074, 0.0444,\n", + " 0.4888, 0.5999, 0.4370, 0.0000, 0.5036, -0.7628, 0.9332,\n", + " -0.6147, 0.7332, 0.3629, 0.9184, 0.7702, -0.8887],\n", + " [ 0.8492, -0.3410, -0.3878, 0.1404, -0.3410, 0.3143, -0.1204,\n", + " 0.5817, 0.4413, 0.5550, 0.6486, -0.1070, 0.6285, -0.4948,\n", + " 0.2006, 0.1605, 0.0535, -0.4079, 0.3811, 0.4948],\n", + " [ 0.6060, 0.7666, -0.8688, -0.6863, -0.5111, -0.0803, -0.6425,\n", + " -0.0146, -0.3577, 0.3431, -0.6571, 0.5622, 0.0000, 0.7374,\n", + " -0.1314, -0.3650, 0.7520, 0.2336, -0.2847, -0.8250],\n", + " [ 0.3014, 0.2950, -0.0898, -0.3142, 0.4040, 0.4681, -0.0705,\n", + " -0.2052, 0.8143, -0.1603, 0.3334, -0.6733, 0.0834, 0.0898,\n", + " -0.4937, 0.1924, 0.0064, 0.4104, 0.6348, -0.3527],\n", + " [-0.6449, 0.5856, -0.0263, -0.0197, 0.8357, -0.5856, 0.0395,\n", + " -0.3422, 0.8028, 0.0855, -0.7238, -0.6317, 0.2764, -0.0461,\n", + " -0.4211, -0.5988, 0.2632, 0.4014, -0.7501, -0.5659]]],\n", + " grad_fn=), scale=tensor(0.0069, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)),\n", + " QuantTensor(value=tensor([[[-0.0066, 0.4804, 0.0066, -0.1184, 0.6843, -0.0197, 0.1448,\n", + " 0.1842, 0.6383, -0.1908, -0.0066, -0.1053, -0.1316, 0.0461,\n", + " -0.0066, -0.2764, 0.3751, 0.3619, 0.5001, -0.1316],\n", + " [-0.6449, 0.5856, -0.0263, -0.0197, 0.8357, -0.5856, 
0.0395,\n", + " -0.3422, 0.8028, 0.0855, -0.7238, -0.6317, 0.2764, -0.0461,\n", + " -0.4211, -0.5988, 0.2632, 0.4014, -0.7501, -0.5659]]],\n", + " grad_fn=), scale=tensor(0.0066, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)))" ] }, "execution_count": 12, @@ -754,20 +768,22 @@ " bias: bool = True,\n", " batch_first: bool = False,\n", " bidirectional: bool = False,\n", - " weight_quant = Int8WeightPerTensorFloat,\n", - " bias_quant = Int32Bias,\n", - " io_quant = Int8ActPerTensorFloat,\n", - " gate_acc_quant = Int8ActPerTensorFloat,\n", - " sigmoid_quant = Uint8ActPerTensorFloat,\n", - " tanh_quant = Int8ActPerTensorFloat,\n", - " cell_state_quant = Int8ActPerTensorFloat,\n", + " weight_quant=Int8WeightPerTensorFloat,\n", + " bias_quant=Int32Bias,\n", + " io_quant=Int8ActPerTensorFloat,\n", + " gate_acc_quant=Int8ActPerTensorFloat,\n", + " sigmoid_quant=Uint8ActPerTensorFloat,\n", + " tanh_quant=Int8ActPerTensorFloat,\n", + " cell_state_quant=Int8ActPerTensorFloat,\n", " coupled_input_forget_gates: bool = False,\n", - " cat_output_cell_states = True,\n", - " shared_input_hidden_weights = False,\n", - " shared_intra_layer_weight_quant = False,\n", - " shared_intra_layer_gate_acc_quant = False,\n", - " shared_cell_state_quant = True,\n", + " cat_output_cell_states=True,\n", + " shared_input_hidden_weights=False,\n", + " shared_intra_layer_weight_quant=False,\n", + " shared_intra_layer_gate_acc_quant=False,\n", + " shared_cell_state_quant=True,\n", " return_quant_tensor: bool = False,\n", + " device: Optional[torch.device] = None,\n", + " dtype: Optional[torch.dtype] = None,\n", " **kwargs):\n", " super(QuantLSTM, self).__init__(\n", " layer_impl=_QuantLSTMLayer,\n", @@ -790,6 +806,8 @@ " shared_intra_layer_gate_acc_quant=shared_intra_layer_gate_acc_quant,\n", " shared_cell_state_quant=shared_cell_state_quant,\n", " return_quant_tensor=return_quant_tensor,\n", + " dtype=dtype,\n", + " device=device,\n", " **kwargs)\n", " if cat_output_cell_states and cell_state_quant is not None and not shared_cell_state_quant:\n", " raise RuntimeError(\"Concatenating cell states requires shared cell quantizers.\")\n", @@ -894,7 +912,16 @@ "cell_type": "code", "execution_count": 18, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "import torch\n", "from brevitas.nn import QuantLSTM\n", @@ -936,7 +963,7 @@ " " ], "text/plain": [ - "" + "" ] }, "execution_count": 19, @@ -958,9 +985,17 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 20, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-02-02 10:22:46.461627098 [W:onnxruntime:, graph.cc:1283 Graph] Initializer onnx::LSTM_93 appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. 
Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.\n" + ] + } + ], "source": [ "import onnxruntime as ort\n", "import numpy as np\n", @@ -1004,37 +1039,7 @@ "skip-execution" ] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Serving 'quant_lstm_weight_only_cifg_4b.onnx' at http://localhost:8082\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "show_netron(export_path, 8082)" ] @@ -1049,9 +1054,17 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 23, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-02-02 10:22:49.697482752 [W:onnxruntime:, graph.cc:1283 Graph] Initializer onnx::LSTM_87 appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.\n" + ] + } + ], "source": [ "import onnxruntime as ort\n", "import numpy as np\n", @@ -1079,7 +1092,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/giuseppe/Documents/git/brevitas/src/brevitas/nn/mixin/base.py:77: UserWarning: Keyword arguments are being passed but they not being used.\n", + "/home/giuseppe/Documents/git/brevitas/src/brevitas/nn/mixin/base.py:78: UserWarning: Keyword arguments are being passed but they not being used.\n", " warn('Keyword arguments are being passed but they not being used.')\n" ] } @@ -1104,37 +1117,7 @@ "skip-execution" ] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Serving 'quant_lstm_weight_only_bidirectional_2_layers.onnx' at http://localhost:8083\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "show_netron(export_path, 8083)" ] @@ -1155,7 +1138,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/giuseppe/Documents/git/brevitas/src/brevitas/nn/mixin/base.py:77: UserWarning: Keyword arguments are being passed but they not being used.\n", + "/home/giuseppe/Documents/git/brevitas/src/brevitas/nn/mixin/base.py:78: UserWarning: Keyword arguments are being passed but they not being used.\n", " warn('Keyword arguments are being passed but they not being used.')\n" ] } @@ -1180,37 +1163,7 @@ "skip-execution" ] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Serving 'quant_lstm_weight_only_bidirectional_2_layers_shared_ih.onnx' at http://localhost:8085\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "show_netron(export_path, 8085)" ] @@ -1225,17 +1178,39 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - 
"c:\\users\\alessand\\documents\\brevitas\\src\\brevitas\\nn\\mixin\\base.py:112: UserWarning: Keyword arguments are being passed but they not being used.\n", + "/home/giuseppe/Documents/git/brevitas/src/brevitas/nn/mixin/base.py:78: UserWarning: Keyword arguments are being passed but they not being used.\n", " warn('Keyword arguments are being passed but they not being used.')\n" ] - }, + } + ], + "source": [ + "import torch\n", + "from brevitas.nn import QuantLSTM\n", + "from brevitas.export import export_onnx_qcdq\n", + "\n", + "quant_lstm_weight_only_bidirectional_2_layers = QuantLSTM(\n", + " input_size=10, hidden_size=20, bidirectional=True, num_layers=2, weight_bit_width=4, shared_intra_layer_weight_quant=True,\n", + " io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)\n", + "export_path = 'quant_lstm_weight_only_bidirectional_2_layers_shared_q.onnx'\n", + "exported_model = export_onnx_qcdq(quant_lstm_weight_only_bidirectional_2_layers, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "tags": [ + "skip-execution" + ] + }, + "outputs": [ { "name": "stdout", "output_type": "stream", @@ -1258,35 +1233,14 @@ " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 24, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], - "source": [ - "import torch\n", - "from brevitas.nn import QuantLSTM\n", - "from brevitas.export import export_onnx_qcdq\n", - "\n", - "quant_lstm_weight_only_bidirectional_2_layers = QuantLSTM(\n", - " input_size=10, hidden_size=20, bidirectional=True, num_layers=2, weight_bit_width=4, shared_intra_layer_weight_quant=True,\n", - " io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)\n", - "export_path = 'quant_lstm_weight_only_bidirectional_2_layers_shared_q.onnx'\n", - "exported_model = export_onnx_qcdq(quant_lstm_weight_only_bidirectional_2_layers, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [ - "skip-execution" - ] - }, - "outputs": [], "source": [ "show_netron(export_path, 8086)" ] @@ -1301,17 +1255,40 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "c:\\users\\alessand\\documents\\brevitas\\src\\brevitas\\nn\\mixin\\base.py:112: UserWarning: Keyword arguments are being passed but they not being used.\n", + "/home/giuseppe/Documents/git/brevitas/src/brevitas/nn/mixin/base.py:78: UserWarning: Keyword arguments are being passed but they not being used.\n", " warn('Keyword arguments are being passed but they not being used.')\n" ] - }, + } + ], + "source": [ + "import torch\n", + "from brevitas.nn import QuantLSTM\n", + "from brevitas.export import export_onnx_qcdq\n", + "\n", + "quant_lstm_weight_only_bidirectional_2_layers = QuantLSTM(\n", + " input_size=10, hidden_size=20, bidirectional=True, num_layers=2, weight_bit_width=4, \n", + " shared_input_hidden_weights=True, shared_intra_layer_weight_quant=True,\n", + " io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)\n", + "export_path = 'quant_lstm_weight_only_bidirectional_2_layers_shared_q_ih.onnx'\n", + "exported_model = export_onnx_qcdq(quant_lstm_weight_only_bidirectional_2_layers, 
(torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "tags": [ + "skip-execution" + ] + }, + "outputs": [ { "name": "stdout", "output_type": "stream", @@ -1334,36 +1311,14 @@ " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 25, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], - "source": [ - "import torch\n", - "from brevitas.nn import QuantLSTM\n", - "from brevitas.export import export_onnx_qcdq\n", - "\n", - "quant_lstm_weight_only_bidirectional_2_layers = QuantLSTM(\n", - " input_size=10, hidden_size=20, bidirectional=True, num_layers=2, weight_bit_width=4, \n", - " shared_input_hidden_weights=True, shared_intra_layer_weight_quant=True,\n", - " io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)\n", - "export_path = 'quant_lstm_weight_only_bidirectional_2_layers_shared_q_ih.onnx'\n", - "exported_model = export_onnx_qcdq(quant_lstm_weight_only_bidirectional_2_layers, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [ - "skip-execution" - ] - }, - "outputs": [], "source": [ "show_netron(export_path, 8087)" ] @@ -1380,8 +1335,37 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 32, "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[W shape_type_inference.cpp:1974] Warning: The shape inference of onnx.brevitas::QuantLSTMCell type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)\n", + "[W shape_type_inference.cpp:1974] Warning: The shape inference of onnx.brevitas::QuantLSTMCell type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)\n", + "[W shape_type_inference.cpp:1974] Warning: The shape inference of onnx.brevitas::QuantLSTMCell type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. 
(function UpdateReliable)\n" + ] + } + ], + "source": [ + "import torch\n", + "from brevitas.nn import QuantLSTM\n", + "from brevitas.export import export_qonnx\n", + "\n", + "quant_lstm = QuantLSTM(input_size=10, hidden_size=20)\n", + "export_path = 'quant_lstm.onnx'\n", + "exported_model = export_qonnx(quant_lstm, (torch.randn(5, 1, 10)), export_path=export_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "tags": [ + "skip-execution" + ] + }, "outputs": [ { "name": "stdout", @@ -1405,33 +1389,14 @@ " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 26, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], - "source": [ - "import torch\n", - "from brevitas.nn import QuantLSTM\n", - "from brevitas.export import export_qonnx\n", - "\n", - "quant_lstm = QuantLSTM(input_size=10, hidden_size=20)\n", - "export_path = 'quant_lstm.onnx'\n", - "exported_model = export_qonnx(quant_lstm, (torch.randn(5, 1, 10)), export_path=export_path)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [ - "skip-execution" - ] - }, - "outputs": [], "source": [ "show_netron(export_path, 8088)" ] @@ -1518,7 +1483,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.5 (default, Oct 25 2019, 15:51:11) \n[GCC 7.3.0]" + "version": "3.11.5" }, "orig_nbformat": 4, "vscode": { diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index 1410ebeb0..9b82927d6 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -181,7 +181,8 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]): return inp def pack_output(self, quant_output: QuantTensor): - if not self.training and self.cache_inference_quant_out and isinstance(quant_output, QuantTensor): + if not self.training and self.cache_inference_quant_out and isinstance(quant_output, + QuantTensor): self._cached_out = _CachedIO(quant_output.detach(), self.cache_quant_io_metadata_only) self._set_global_is_quant_layer(False) if self.return_quant_tensor: @@ -291,7 +292,7 @@ def maybe_quantize_state(self, inp, state, quant): def pack_quant_outputs(self, quant_outputs): # In export mode, quant_outputs has the shape of the output concatenated value if self.export_mode: - if self.return_quant_tensor: + if self.return_quant_tensor and self.io_quant.is_quant_enabled: return QuantTensor( quant_outputs, self.io_quant.scale(), @@ -321,7 +322,7 @@ def pack_quant_outputs(self, quant_outputs): def pack_quant_state(self, quant_state, quant): if self.export_mode: - if self.return_quant_tensor: + if self.return_quant_tensor and quant.is_quant_enabled: quant_state = QuantTensor( torch.unsqueeze(quant_state, dim=0), quant.scale(), From f6559cb95f26a534cc5c4967452ef9a65257d043 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 2 Feb 2024 10:39:38 +0000 Subject: [PATCH 05/32] Fix tests --- tests/brevitas/fx/test_tracer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/brevitas/fx/test_tracer.py b/tests/brevitas/fx/test_tracer.py index be5698d2c..d7ef1d3ac 100644 --- a/tests/brevitas/fx/test_tracer.py +++ b/tests/brevitas/fx/test_tracer.py @@ -232,8 +232,8 @@ def test_module(module): @pytest.mark.parametrize('module', QUANT_TENSOR_MODULES) def test_quant_module(module): mod = module() - x = QuantTensor(torch.randn(INPUT_SIZE)) - x_trace = QuantTensor(torch.randn(INPUT_SIZE)) + x = torch.randn(INPUT_SIZE) + x_trace = torch.randn(INPUT_SIZE) with torch.no_grad(): out 
= mod(x)
         graph_model = value_trace(mod, value_args={'x': x_trace})

From 3ec1be27461c7aa259e52d32e8eb7eeb47414bf2 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Fri, 2 Feb 2024 10:40:12 +0000
Subject: [PATCH 06/32] Pre commit

---
 src/brevitas/nn/quant_layer.py        | 3 ++-
 src/brevitas/quant_tensor/__init__.py | 3 +--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py
index ee5b8e628..538d94852 100644
--- a/src/brevitas/nn/quant_layer.py
+++ b/src/brevitas/nn/quant_layer.py
@@ -349,7 +349,8 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
             quant_bias_value = getattr(quant_bias, 'value', quant_bias)
             quant_bias_scale = getattr(quant_bias, 'scale', None)
             quant_bias_bitwidth = getattr(quant_bias, 'bit_width', None)
-            if not self.training and self.cache_inference_quant_bias and isinstance(quant_bias, QuantTensor):
+            if not self.training and self.cache_inference_quant_bias and isinstance(quant_bias,
+                                                                                    QuantTensor):
                 self._cached_bias = _CachedIO(quant_bias.detach(), metadata_only=False)
             output_tensor = self.inner_forward_impl(
                 return_value(quant_input), return_value(quant_weight), return_value(quant_bias))
diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py
index 9c80f6969..08dec65e0 100644
--- a/src/brevitas/quant_tensor/__init__.py
+++ b/src/brevitas/quant_tensor/__init__.py
@@ -92,8 +92,7 @@ def __new__(
         elif not _allow_empty and (scale is None or bit_width is None or zero_point is None):
             raise RuntimeError("To create an empty QuantTensor, set _allow_empty=True")

-        quant_tensor = super().__new__(
-            cls, value, scale, zero_point, bit_width, signed, training)
+        quant_tensor = super().__new__(cls, value, scale, zero_point, bit_width, signed, training)
         return quant_tensor

     @property

From 0fa9213b735519cd08cd3abc30ca9ce7c352be73 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Fri, 2 Feb 2024 11:32:23 +0000
Subject: [PATCH 07/32] Fix return_quant_tensor during calibration

---
 src/brevitas/graph/calibrate.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py
index 79cedff7f..8f690fc9b 100644
--- a/src/brevitas/graph/calibrate.py
+++ b/src/brevitas/graph/calibrate.py
@@ -48,6 +48,21 @@ BN_LAYERS = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


+def disable_return_quant_tensor(model):
+    previous_state = {}
+    for module in model.modules():
+        if hasattr(module, 'return_quant_tensor'):
+            previous_state[module] = module.return_quant_tensor
+            module.return_quant_tensor = False
+    return previous_state
+
+
+def restore_return_quant_tensor(model, previous_state):
+    for module in model.modules():
+        if hasattr(module, 'return_quant_tensor'):
+            module.return_quant_tensor = previous_state[module]
+
+
 def extend_collect_stats_steps(module):
     if hasattr(module, 'collect_stats_steps'):
         # We extend the collect steps in PTQ to match potentially long calibrations
@@ -75,11 +90,13 @@ def __init__(self, model, enabled=True):
         self.previous_training_state = model.training
         self.disable_quant_inference = DisableEnableQuantization(call_act_quantizer_impl=True)
         self.enabled = enabled
+        self.return_quant_tensor_state = dict()

     def __enter__(self):
         if self.enabled:
             self.model.apply(extend_collect_stats_steps)
             self.model.apply(set_collect_stats_to_average)
+            self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
             self.disable_quant_inference.apply(
                 self.model, is_training=True, 
quantization_enabled=False) @@ -87,6 +104,7 @@ def __exit__(self, type, value, traceback): self.model.apply(finalize_collect_stats) self.disable_quant_inference.apply( self.model, is_training=self.previous_training_state, quantization_enabled=True) + restore_return_quant_tensor(self.model, self.return_quant_tensor_state) class load_quant_model: From 4eefe9be7962ffc582306c2e5e2b8774e1a4118e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 2 Feb 2024 13:34:04 +0000 Subject: [PATCH 08/32] Fix for FlexmlAvgpool --- src/brevitas/nn/target/flexml.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/brevitas/nn/target/flexml.py b/src/brevitas/nn/target/flexml.py index e98438f75..66daa94cf 100644 --- a/src/brevitas/nn/target/flexml.py +++ b/src/brevitas/nn/target/flexml.py @@ -97,11 +97,13 @@ def _avg_scaling(self): def forward(self, input: Union[Tensor, QuantTensor]): x = self.unpack_input(input) - x = x.set(value=super(FlexMLQuantAvgPool2d, self).forward(x.value) * self.rescaling_const) - if x.scale is not None: + if isinstance(x, QuantTensor): + x = x.set( + value=super(FlexMLQuantAvgPool2d, self).forward(x.value) * self.rescaling_const) x = x.set(scale=x.scale * self.quantized_div_scale) - if x.bit_width is not None: x = x.set(bit_width=self.max_acc_bit_width(x.bit_width)) + else: + x = super(FlexMLQuantAvgPool2d, self).forward(x) * self.rescaling_const return self.pack_output(x) def max_acc_bit_width(self, input_bit_width): From 741bf948a9a4e6211fbe091fc3c30ad5bbc9d5d9 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 2 Feb 2024 14:09:31 +0000 Subject: [PATCH 09/32] Fix ORT tests --- tests/brevitas_ort/common.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index e01596dc9..fa324f0be 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -27,6 +27,7 @@ from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat +from brevitas.quant_tensor import QuantTensor from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat SEED = 123456 @@ -116,10 +117,15 @@ def is_brevitas_ort_close( input_t = torch.from_numpy(np_input) with torch.no_grad(): brevitas_output = model(input_t) - computed_out = brevitas_output.value + if isinstance(brevitas_output, QuantTensor): + computed_out = brevitas_output.value + scale = brevitas_output.scale + else: + computed_out = brevitas_output + scale = 1. 
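        # [editor's note — annotation, not part of the original patch]
        # When the model returns a plain torch.Tensor (e.g. return_quant_tensor=False
        # or quantization disabled), `scale` falls back to 1. and the qcdq tolerance
        # below stays an absolute float instead of being rescaled by the output
        # quantization scale. A minimal sketch of the unwrap-or-passthrough pattern
        # this series applies throughout (`unwrap` is a hypothetical helper name):
        #
        #     def unwrap(output):
        #         if isinstance(output, QuantTensor):
        #             return output.value, output.scale  # dequantized value + scale
        #         return output, 1.  # raw tensor: unit-scale fallback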
if tolerance is not None and export_type == 'qcdq': - tolerance = tolerance * brevitas_output.scale # Float Output, tolerance is +/- output scale + tolerance = tolerance * scale # Float Output, tolerance is +/- output scale if export_type == 'qonnx': exported_model = export_qonnx(model, input_t, export_path=export_name) From 1f02f7639909698fe8be32ba6237b92a59d417f3 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 2 Feb 2024 23:49:35 +0000 Subject: [PATCH 10/32] Partial review --- src/brevitas/nn/mixin/base.py | 21 ++++------- src/brevitas/nn/quant_layer.py | 67 ++++++++++++---------------------- src/brevitas/nn/quant_rnn.py | 8 ++-- 3 files changed, 35 insertions(+), 61 deletions(-) diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index 9b82927d6..047e7219b 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -18,7 +18,6 @@ from brevitas.inject import ExtendedInjector from brevitas.inject import Injector from brevitas.nn.utils import compute_channel_view_shape -from brevitas.quant_tensor import _is_all_nested_not_none from brevitas.quant_tensor import QuantTensor from .utils import filter_kwargs @@ -167,12 +166,6 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]): if not self.training and not self._export_mode and self.cache_inference_quant_inp: cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only) self._cached_inp = cached_inp - # else: - # if not self.training and self.cache_inference_quant_inp: - # cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only) - # self._cached_inp = cached_inp - # Remove any naming metadata to avoid dowmstream errors - # Avoid inplace operations on the input in case of forward hooks if not torch._C._get_tracing_state(): if isinstance(inp, QuantTensor): inp = inp.set(value=inp.value.rename(None)) @@ -255,11 +248,11 @@ def gate_params_fwd(gate, quant_input): quant_weight_hh = gate.hidden_weight() if isinstance(quant_input, QuantTensor): acc_bit_width = None # TODO - if getattr(quant_input, 'scale', None) is not None and getattr( - quant_weight_ih, 'scale', None) is not None: - acc_scale_shape = compute_channel_view_shape(quant_input.value, channel_dim=1) - acc_scale = quant_weight_ih.scale.view(acc_scale_shape) - acc_scale = acc_scale * quant_input.scale.view(acc_scale_shape) + if isinstance(quant_input, QuantTensor) and isinstance(quant_weight_ih, QuantTensor): + if quant_input.scale is not None and quant_weight_ih.scale is not None: + acc_scale_shape = compute_channel_view_shape(quant_input.value, channel_dim=1) + acc_scale = quant_weight_ih.scale.view(acc_scale_shape) + acc_scale = acc_scale * quant_input.scale.view(acc_scale_shape) quant_bias = gate.bias_quant(gate.bias, acc_scale, acc_bit_width) return quant_weight_ih, quant_weight_hh, quant_bias @@ -303,7 +296,7 @@ def pack_quant_outputs(self, quant_outputs): else: return quant_outputs seq_dim = 1 if self.cell.batch_first else 0 - if self.return_quant_tensor: + if self.return_quant_tensor and self.io_quant.is_quant_enabled: outputs = [ QuantTensor( torch.unsqueeze(quant_output[0], dim=seq_dim), @@ -333,7 +326,7 @@ def pack_quant_state(self, quant_state, quant): else: quant_state = torch.unsqueeze(quant_state, dim=0) else: - if self.return_quant_tensor: + if self.return_quant_tensor and quant.is_quant_enabled: quant_state = QuantTensor( torch.unsqueeze(quant_state[0], dim=0), quant_state[1], diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index 538d94852..9590f8e11 100644 
--- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -10,6 +10,7 @@ from torch.nn import Module from torch.nn import Parameter +from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import QuantTensor from .mixin import * @@ -19,10 +20,6 @@ from .utils import rename_state_dict_by_prefix -def return_value(tensor): - return tensor.value if isinstance(tensor, QuantTensor) else tensor - - class QuantNonLinearActLayer(QuantNonLinearActMixin, QuantInputMixin, QuantLayerMixin, Module): __metaclass__ = ABCMeta @@ -299,11 +296,6 @@ def merge_bn_in(self, bn): merge_bn(self, bn, output_channel_dim=self.output_channel_dim) def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: - output_scale = None - output_bit_width = None - output_zero_point = None - output_signed = None - inp = self.unpack_input(inp) # shortcut execution through the export impl during export @@ -314,14 +306,8 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe return out quant_input = self.input_quant(inp) - # quant_input_value = getattr(quant_input, 'value', quant_input) - # quant_input_scale = getattr(quant_input, 'scale', None) - # quant_input_bitwidth = getattr(quant_input, 'bit_width', None) - quant_weight = self.quant_weight(quant_input) - # quant_weight_value = getattr(quant_weight, 'value', quant_weight) - # quant_weight_scale = getattr(quant_weight, 'scale', None) - # quant_weight_bitwidth = getattr(quant_weight, 'bit_width', None) + compute_output_quant_tensor = isinstance(quant_input, QuantTensor) and isinstance( quant_weight, QuantTensor) if not (compute_output_quant_tensor or @@ -337,42 +323,37 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe output_scale = self.quant_output_scale_impl( inp, quant_input.scale, quant_weight.scale) - - quant_input_signed = quant_input.signed if isinstance( - quant_input, QuantTensor) else True - quant_weight_signed = quant_weight.signed if isinstance( - quant_weight, QuantTensor) else True - output_signed = quant_input_signed or quant_weight_signed + output_signed = quant_input.signed or quant_weight.signed if self.bias is not None: quant_bias = self.bias_quant(self.bias, output_scale, output_bit_width) - quant_bias_value = getattr(quant_bias, 'value', quant_bias) - quant_bias_scale = getattr(quant_bias, 'scale', None) - quant_bias_bitwidth = getattr(quant_bias, 'bit_width', None) if not self.training and self.cache_inference_quant_bias and isinstance(quant_bias, QuantTensor): self._cached_bias = _CachedIO(quant_bias.detach(), metadata_only=False) output_tensor = self.inner_forward_impl( - return_value(quant_input), return_value(quant_weight), return_value(quant_bias)) - - if (self.return_quant_tensor and output_scale is not None and - (quant_bias_scale is None or - (quant_bias_scale is not None and - quant_bias_scale.data_ptr() != output_scale.data_ptr()))): - channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1 - output_scale_broadcast_shape = compute_channel_view_shape( - inp, channel_dim=channel_dim) - output_zero_point = -quant_bias_value.view( - output_scale_broadcast_shape) / output_scale - - if hasattr(quant_bias, 'bit_width' - ) and quant_bias_bitwidth is not None and output_bit_width is not None: - output_bit_width = torch.where( - quant_bias_bitwidth > output_bit_width, quant_bias_bitwidth, output_bit_width) - output_bit_width = output_bit_width + 1 + _unpack_quant_tensor(quant_input), + 
_unpack_quant_tensor(quant_weight), + _unpack_quant_tensor(quant_bias)) + + if (self.return_quant_tensor and isinstance(quant_bias, QuantTensor)): + + if output_scale is not None and quant_bias.scale.data_ptr( + ) != output_scale.data_ptr(): + channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1 + output_scale_broadcast_shape = compute_channel_view_shape( + inp, channel_dim=channel_dim) + output_zero_point = -quant_bias.value.view( + output_scale_broadcast_shape) / output_scale + + if output_bit_width is not None: + output_bit_width = torch.where( + quant_bias.bit_width > output_bit_width, + quant_bias.bit_width, + output_bit_width) + output_bit_width = output_bit_width + 1 else: output_tensor = self.inner_forward_impl( - return_value(quant_input), return_value(quant_weight), None) + _unpack_quant_tensor(quant_input), _unpack_quant_tensor(quant_weight), None) if self.return_quant_tensor and not self.is_output_quant_enabled: if (isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor) and diff --git a/src/brevitas/nn/quant_rnn.py b/src/brevitas/nn/quant_rnn.py index 6e7ec581a..eed067cd6 100644 --- a/src/brevitas/nn/quant_rnn.py +++ b/src/brevitas/nn/quant_rnn.py @@ -684,19 +684,19 @@ def forward(self, inp, hidden_state, cell_state): quant_weight_if, quant_weight_hf, quant_bias_forget = self.gate_params_fwd( self.forget_gate_params, quant_input) # Handle None bias by setting it 0. - if getattr(quant_bias_input, 'value', quant_bias_input) is None: + if quant_bias_input is None: quant_bias_input = torch.tensor(0., device=quant_input_value.device) else: quant_bias_input = _unpack_quant_tensor(quant_bias_input) - if getattr(quant_bias_forget, 'value', quant_bias_forget) is None: + if quant_bias_forget is None: quant_bias_forget = torch.tensor(0., device=quant_input_value.device) else: quant_bias_forget = _unpack_quant_tensor(quant_bias_forget) - if getattr(quant_bias_cell, 'value', quant_bias_cell) is None: + if quant_bias_cell is None: quant_bias_cell = torch.tensor(0., device=quant_input_value.device) else: quant_bias_cell = _unpack_quant_tensor(quant_bias_cell) - if getattr(quant_bias_output, 'value', quant_bias_output) is None: + if quant_bias_output is None: quant_bias_output = torch.tensor(0., device=quant_input_value.device) else: quant_bias_output = _unpack_quant_tensor(quant_bias_output) From 7a60bea36c9d6cfe44209c1d4a4c58dd2cd3e314 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sun, 4 Feb 2024 00:55:16 +0000 Subject: [PATCH 11/32] Fix for output quant metadata --- src/brevitas/nn/quant_layer.py | 5 +++++ src/brevitas/nn/quant_rnn.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index 9590f8e11..dcc9d6f58 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -296,6 +296,11 @@ def merge_bn_in(self, bn): merge_bn(self, bn, output_channel_dim=self.output_channel_dim) def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: + output_scale = None + output_bit_width = None + output_signed = None + output_zero_point = None + inp = self.unpack_input(inp) # shortcut execution through the export impl during export diff --git a/src/brevitas/nn/quant_rnn.py b/src/brevitas/nn/quant_rnn.py index eed067cd6..accb456f3 100644 --- a/src/brevitas/nn/quant_rnn.py +++ b/src/brevitas/nn/quant_rnn.py @@ -419,7 +419,7 @@ def forward(self, inp, state): quant_weight_ih, quant_weight_hh, quant_bias = self.gate_params_fwd( 
self.gate_params, quant_input) quant_input_value = _unpack_quant_tensor(quant_input) - if getattr(quant_bias, 'value', quant_bias) is None: + if quant_bias is None: quant_bias = torch.tensor(0., device=quant_input_value.device) else: quant_bias = _unpack_quant_tensor(quant_bias) From 4ef70de5dc9354d6920e137929a226f85ee96076 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sun, 4 Feb 2024 02:16:17 +0000 Subject: [PATCH 12/32] Fix output zp quant layer --- src/brevitas/nn/quant_layer.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index dcc9d6f58..206ab9584 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -340,17 +340,17 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe _unpack_quant_tensor(quant_weight), _unpack_quant_tensor(quant_bias)) - if (self.return_quant_tensor and isinstance(quant_bias, QuantTensor)): - - if output_scale is not None and quant_bias.scale.data_ptr( - ) != output_scale.data_ptr(): - channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1 - output_scale_broadcast_shape = compute_channel_view_shape( - inp, channel_dim=channel_dim) - output_zero_point = -quant_bias.value.view( - output_scale_broadcast_shape) / output_scale - - if output_bit_width is not None: + if (self.return_quant_tensor): + if output_scale is not None: + if (isinstance(quant_bias, QuantTensor) and quant_bias.scale.data_ptr() != + output_scale.data_ptr()) or not isinstance(quant_bias, QuantTensor): + channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1 + output_scale_broadcast_shape = compute_channel_view_shape( + inp, channel_dim=channel_dim) + output_zero_point = -_unpack_quant_tensor(quant_bias).view( + output_scale_broadcast_shape) / output_scale + + if output_bit_width is not None and isinstance(quant_bias, QuantTensor): output_bit_width = torch.where( quant_bias.bit_width > output_bit_width, quant_bias.bit_width, From 36b521329408e5198fe104241f48015cddebc9f1 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sun, 4 Feb 2024 16:26:18 +0000 Subject: [PATCH 13/32] Fix return lstm --- tests/brevitas/nn/test_nn_quantizers.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/tests/brevitas/nn/test_nn_quantizers.py b/tests/brevitas/nn/test_nn_quantizers.py index 7f2bdcd7d..b7d73ee18 100644 --- a/tests/brevitas/nn/test_nn_quantizers.py +++ b/tests/brevitas/nn/test_nn_quantizers.py @@ -121,23 +121,18 @@ def test_quant_lstm_rnn_full(model_input, current_cases): if h is not None: if return_quant_tensor and kwargs['io_quant'] is not None: assert isinstance(h, QuantTensor) + assert h.scale is not None + assert h.bit_width is not None else: assert isinstance(h, torch.Tensor) if c is not None: - if kwargs['signed_act'] is None or not kwargs['return_quant_tensor']: - if not kwargs['bidirectional']: - if not kwargs['return_quant_tensor']: - assert isinstance(c, torch.Tensor) - elif kwargs['return_quant_tensor'] and kwargs['signed_act'] is None and kwargs[ - 'num_layers'] == 2: - assert isinstance(c, torch.Tensor) - else: - assert isinstance(c, QuantTensor) - else: - assert isinstance(c, torch.Tensor) + if kwargs['signed_act'] is None or not return_quant_tensor: + assert isinstance(c, torch.Tensor) else: assert isinstance(c, QuantTensor) + assert c.scale is not None + assert c.bit_width is not None @pytest_cases.parametrize_with_cases('model_input', cases=[case_quant_lstm, 
case_quant_rnn]) From 2a57e1906a49bdaf0a60ea011b27432c1ae89ce4 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sun, 4 Feb 2024 20:04:26 +0000 Subject: [PATCH 14/32] Better quant lstm --- src/brevitas/nn/mixin/base.py | 7 +--- src/brevitas/nn/quant_rnn.py | 4 +- src/brevitas/quant_tensor/__init__.py | 35 ++++------------ tests/brevitas/nn/nn_quantizers_fixture.py | 9 +++- tests/brevitas/nn/test_nn_quantizers.py | 49 +++++++++------------- 5 files changed, 39 insertions(+), 65 deletions(-) diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index 047e7219b..7ef184de5 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -277,7 +277,6 @@ def maybe_quantize_state(self, inp, state, quant): batch_size = inp.size(0) if self.cell.batch_first else inp.size(1) quant_state = torch.zeros( int(batch_size), self.hidden_size, dtype=inp.dtype, device=inp.device) - # quant_state = QuantTensor(quant_state) else: quant_state = quant(state) return quant_state @@ -304,8 +303,7 @@ def pack_quant_outputs(self, quant_outputs): quant_output[2], quant_output[3], self.io_quant.is_signed, - self.training, - _allow_empty=True) for quant_output in quant_outputs] + self.training) for quant_output in quant_outputs] else: outputs = [torch.unsqueeze(o[0], dim=seq_dim) for o in quant_outputs] if self.reverse_input: @@ -333,8 +331,7 @@ def pack_quant_state(self, quant_state, quant): quant_state[2], quant_state[3], quant.is_signed, - training=self.training, - _allow_empty=True) + training=self.training) else: quant_state = torch.unsqueeze(quant_state[0], dim=0) return quant_state diff --git a/src/brevitas/nn/quant_rnn.py b/src/brevitas/nn/quant_rnn.py index accb456f3..745aef3e4 100644 --- a/src/brevitas/nn/quant_rnn.py +++ b/src/brevitas/nn/quant_rnn.py @@ -24,7 +24,6 @@ from brevitas.quant import Int32Bias from brevitas.quant import Uint8ActPerTensorFloat from brevitas.quant_tensor import _unpack_quant_tensor -from brevitas.quant_tensor import QuantTensor QuantTupleShortEnabled = List[Tuple[Tensor, Tensor, Tensor, Tensor]] QuantTupleShortDisabled = List[Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]] @@ -971,6 +970,9 @@ def __init__( **kwargs) if cat_output_cell_states and cell_state_quant is not None and not shared_cell_state_quant: raise RuntimeError("Concatenating cell states requires shared cell quantizers.") + if return_quant_tensor and (io_quant is None or cell_state_quant is None): + raise RuntimeError( + "To return a valid QuantTensor, specify a io_quant and cell_state_quant") self.cat_output_cell_states = cat_output_cell_states def forward(self, inp, hx=None, cx=None): diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index 08dec65e0..865211b8f 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -59,39 +59,18 @@ def _is_all_nested_not_none(input_data): class QuantTensor(QuantTensorBase): - def __new__( - cls, - value=None, - scale=None, - zero_point=None, - bit_width=None, - signed=None, - training=None, - _allow_empty=False): - - if scale is not None and not isinstance(scale, torch.Tensor): + def __new__(cls, value, scale, zero_point, bit_width, signed, training): + + if not isinstance(scale, torch.Tensor): scale = torch.tensor(scale, dtype=torch.float) - if zero_point is not None and not isinstance(zero_point, torch.Tensor): + if not isinstance(zero_point, torch.Tensor): zero_point = torch.tensor(zero_point, dtype=torch.float) - if bit_width is not None 
and not isinstance(bit_width, torch.Tensor): + if not isinstance(bit_width, torch.Tensor): bit_width = torch.tensor(bit_width, dtype=torch.float) - if signed is not None and not isinstance(signed, torch.Tensor): + if not isinstance(signed, torch.Tensor): signed = torch.tensor(signed, dtype=torch.bool) - if training is not None and not isinstance(training, torch.Tensor): + if not isinstance(training, torch.Tensor): training = torch.tensor(training, dtype=torch.bool) - - if _allow_empty: - warnings.warn( - "Empty QuantTensor are deprecated and will be removed in a future version") - # elif value is not None and scale is not None and zero_point is not None: - # is_int = torch.allclose(torch.round(int_value), int_value) - # if not is_int: - # quant_tensor = quant_tensor.set(int_value = torch.round(int_value / scale + zero_point)) - # elif int_value is None and value is not None: - # pass - elif not _allow_empty and (scale is None or bit_width is None or zero_point is None): - raise RuntimeError("To create an emtpy QuantTensor, set _allow_empty=True") - quant_tensor = super().__new__(cls, value, scale, zero_point, bit_width, signed, training) return quant_tensor diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py index 538e836e8..783f906b7 100644 --- a/tests/brevitas/nn/nn_quantizers_fixture.py +++ b/tests/brevitas/nn/nn_quantizers_fixture.py @@ -410,7 +410,14 @@ def forward(self, x): return self.lstm(x) torch.random.manual_seed(SEED) - module = Model() + if return_quant_tensor and (io_quantizer is None or signed_act_quantizer is None): + with pytest.raises( + RuntimeError, + match="To return a valid QuantTensor, specify a io_quant and cell_state_quant"): + module = Model() + module = None + else: + module = Model() in_size = (FEATURES, 1, IN_CH) inp = torch.randn(in_size) diff --git a/tests/brevitas/nn/test_nn_quantizers.py b/tests/brevitas/nn/test_nn_quantizers.py index b7d73ee18..2b74b0c10 100644 --- a/tests/brevitas/nn/test_nn_quantizers.py +++ b/tests/brevitas/nn/test_nn_quantizers.py @@ -77,7 +77,6 @@ def test_quant_wbiol(model_input, current_cases): @pytest_cases.parametrize_with_cases( 'model_input', cases=[case_quant_lstm_full, case_quant_rnn_full]) def test_quant_lstm_rnn_full(model_input, current_cases): - model, input = model_input cases_generator_func = current_cases['model_input'][1] case_id = get_case_id(cases_generator_func) @@ -85,7 +84,11 @@ def test_quant_lstm_rnn_full(model_input, current_cases): kwargs = parse_args(args) is_input_quanttensor = kwargs['io_quant'] is not None or kwargs['input_quantized'] + return_quant_tensor = kwargs['return_quant_tensor'] + if return_quant_tensor and (kwargs['io_quant'] is None or kwargs['signed_act'] is None): + pytest.skip("Invalid config") + model, input = model_input if (kwargs['bias_quant'] == 'quant_external') and ( \ (not is_input_quanttensor or kwargs['weight_quant'] is None) or \ (kwargs['num_layers']> 1 and (kwargs['weight_quant'] is None or kwargs['io_quant'] is None))): @@ -103,23 +106,16 @@ def test_quant_lstm_rnn_full(model_input, current_cases): else: output, h = output c = None - return_quant_tensor = kwargs['return_quant_tensor'] if return_quant_tensor: assert isinstance(output, QuantTensor) - # Empty QuantTensor - if ( not kwargs['input_quantized'] or kwargs['weight_quant'] is None) and \ - kwargs['io_quant'] is None: - assert output.scale is None - assert output.bit_width is None - else: # "Full" QuantTensor - assert output.scale is not None - assert output.bit_width is not 
None + assert output.scale is not None + assert output.bit_width is not None else: assert isinstance(output, torch.Tensor) if h is not None: - if return_quant_tensor and kwargs['io_quant'] is not None: + if return_quant_tensor: assert isinstance(h, QuantTensor) assert h.scale is not None assert h.bit_width is not None @@ -127,7 +123,7 @@ def test_quant_lstm_rnn_full(model_input, current_cases): assert isinstance(h, torch.Tensor) if c is not None: - if kwargs['signed_act'] is None or not return_quant_tensor: + if not return_quant_tensor: assert isinstance(c, torch.Tensor) else: assert isinstance(c, QuantTensor) @@ -163,30 +159,28 @@ def test_quant_lstm_rnn(model_input, current_cases): else: output, h = output c = None - return_quant_tensor = kwargs['return_quant_tensor'] and kwargs['io_quant'] is not None + return_quant_tensor = kwargs['return_quant_tensor'] if return_quant_tensor: assert isinstance(output, QuantTensor) - # Empty QuantTensor - if ( not kwargs['input_quantized'] or kwargs['weight_quant'] is None) and \ - kwargs['io_quant'] is None: - assert output.scale is None - assert output.bit_width is None - else: # "Full" QuantTensor - assert output.scale is not None - assert output.bit_width is not None + assert output.scale is not None + assert output.bit_width is not None else: assert isinstance(output, torch.Tensor) if h is not None: - if return_quant_tensor and kwargs['io_quant'] is not None: + if return_quant_tensor: assert isinstance(h, QuantTensor) + assert h.scale is not None + assert h.bit_width is not None else: assert isinstance(h, torch.Tensor) if c is not None: - if return_quant_tensor and kwargs['io_quant'] is not None: + if return_quant_tensor: assert isinstance(c, QuantTensor) + assert c.scale is not None + assert c.bit_width is not None else: assert isinstance(c, torch.Tensor) @@ -218,12 +212,7 @@ def test_quant_mha(model_input, current_cases): if kwargs['return_quant_tensor']: assert isinstance(output, QuantTensor) - # Empty QuantTensor - if kwargs['io_quant'] is None: - assert output.scale is None - assert output.bit_width is None - else: # "Full" QuantTensor - assert output.scale is not None - assert output.bit_width is not None + assert output.scale is not None + assert output.bit_width is not None else: assert isinstance(output, torch.Tensor) From 66ed2ba7c4ab3492d0299820b0629f560a3d5c71 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sun, 4 Feb 2024 23:21:11 +0000 Subject: [PATCH 15/32] Fix for empty signed/training --- ...1_quant_tensor_quant_conv2d_overview.ipynb | 203 +++++++++--------- notebooks/03_anatomy_of_a_quantizer.ipynb | 109 ++++------ src/brevitas/proxy/runtime_quant.py | 3 +- 3 files changed, 149 insertions(+), 166 deletions(-) diff --git a/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb b/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb index ef7fd28ca..c9a6d052d 100644 --- a/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb +++ b/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb @@ -161,17 +161,17 @@ { "data": { "text/plain": [ - "tensor([[[[-0.3189, -0.2848, -0.0037],\n", - " [ 0.2287, 0.7919, -0.2949],\n", - " [ 0.7699, 0.6641, -0.1161]],\n", + "tensor([[[[ 1.0093, 0.4820, 0.0156],\n", + " [-0.1535, -0.2748, -0.9393],\n", + " [-1.0662, 0.2397, 0.0932]],\n", "\n", - " [[-0.0886, -0.1660, 1.7264],\n", - " [ 0.8113, 0.8065, -0.8843],\n", - " [-0.3388, -0.1821, -0.3209]],\n", + " [[ 0.6932, -0.2772, 0.0703],\n", + " [ 0.2536, 0.1734, -0.3745],\n", + " [-0.5633, 0.2231, -0.6844]],\n", "\n", - " [[ 0.4528, -0.1083, 
1.2154],\n", - " [ 1.4329, 1.5554, 1.5001],\n", - " [ 1.0284, 1.4550, 0.5717]]]], grad_fn=)" + " [[-0.2607, 0.2174, -0.0522],\n", + " [ 0.1215, -0.3744, -0.5880],\n", + " [-0.3104, -0.6930, 0.5322]]]], grad_fn=)" ] }, "execution_count": 4, @@ -243,31 +243,31 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.0653, 0.0989, 0.0071],\n", - " [-0.1871, -0.0247, -0.0671],\n", - " [ 0.1642, 0.1624, 0.0053]],\n", + "QuantTensor(value=tensor([[[[ 0.0236, 0.1599, 0.1799],\n", + " [-0.0545, 0.2144, 0.2126],\n", + " [-0.1363, -0.2271, -0.1526]],\n", "\n", - " [[ 0.1306, -0.0335, 0.1448],\n", - " [ 0.1483, -0.0671, -0.2101],\n", - " [ 0.1713, -0.1465, -0.1448]]],\n", + " [[-0.0872, -0.0091, -0.1090],\n", + " [ 0.0690, -0.0327, 0.2289],\n", + " [ 0.2307, 0.0073, -0.1326]]],\n", "\n", "\n", - " [[[-0.1448, 0.0600, -0.1201],\n", - " [ 0.1218, -0.1642, 0.1889],\n", - " [ 0.0618, 0.2101, -0.2242]],\n", + " [[[-0.0254, 0.0418, -0.0363],\n", + " [-0.2053, 0.2071, -0.1163],\n", + " [-0.1163, -0.1653, 0.0109]],\n", "\n", - " [[-0.0600, 0.0530, 0.0335],\n", - " [ 0.1201, 0.1571, 0.1254],\n", - " [ 0.1660, 0.0159, -0.0830]]],\n", + " [[-0.2107, -0.1199, 0.0799],\n", + " [ 0.0200, 0.0218, 0.1817],\n", + " [-0.1199, -0.0963, -0.0600]]],\n", "\n", "\n", - " [[[ 0.0106, 0.1536, 0.1730],\n", - " [ 0.1942, 0.0424, 0.2225],\n", - " [ 0.1324, 0.1907, 0.0441]],\n", + " [[[-0.0709, -0.0908, 0.1544],\n", + " [-0.0236, -0.2235, 0.2180],\n", + " [-0.0799, -0.0200, 0.0273]],\n", "\n", - " [[ 0.1942, 0.1236, 0.1889],\n", - " [-0.0124, 0.0742, -0.2048],\n", - " [ 0.1271, -0.1607, -0.1924]]]], grad_fn=), scale=tensor(0.0018, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" + " [[-0.1998, 0.1126, 0.1435],\n", + " [ 0.0818, 0.1399, 0.1181],\n", + " [ 0.1762, -0.1726, -0.2216]]]], grad_fn=), scale=tensor(0.0018, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" ] }, "execution_count": 6, @@ -468,7 +468,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/tmp/ipykernel_1410693/1377665000.py:1: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n", + "/tmp/ipykernel_2482988/1377665000.py:1: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. 
(Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n", " torch.tanh(quant_tensor)\n" ] }, @@ -786,7 +786,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 24, "metadata": { "tags": [ "raises-exception" @@ -804,7 +804,7 @@ "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compiled_call_impl(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs) \u001b[39m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_impl(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1529\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n", "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_conv.py:198\u001b[0m, in \u001b[0;36mQuantConv2d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m: Union[Tensor, QuantTensor]) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Union[Tensor, QuantTensor]:\n\u001b[0;32m--> 198\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mforward_impl(\u001b[39minput\u001b[39;49m)\n", - "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py:328\u001b[0m, in \u001b[0;36mQuantWeightBiasInputOutputLayer.forward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 324\u001b[0m compute_output_quant_tensor \u001b[39m=\u001b[39m \u001b[39misinstance\u001b[39m(quant_input, QuantTensor) \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(\n\u001b[1;32m 325\u001b[0m quant_weight, QuantTensor)\n\u001b[1;32m 326\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (compute_output_quant_tensor \u001b[39mor\u001b[39;00m\n\u001b[1;32m 327\u001b[0m 
\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_output_quant_enabled) \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreturn_quant_tensor:\n\u001b[0;32m--> 328\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mQuantLayer is not correctly configured\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 330\u001b[0m \u001b[39mif\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreturn_quant_tensor \u001b[39mor\u001b[39;00m\n\u001b[1;32m 331\u001b[0m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_bias_quant_enabled \u001b[39mand\u001b[39;00m\n\u001b[1;32m 332\u001b[0m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias_quant\u001b[39m.\u001b[39mrequires_input_scale \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias_quant\u001b[39m.\u001b[39mrequires_input_bit_width))):\n\u001b[1;32m 333\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(quant_input, QuantTensor) \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(quant_weight, QuantTensor):\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py:320\u001b[0m, in \u001b[0;36mQuantWeightBiasInputOutputLayer.forward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 316\u001b[0m compute_output_quant_tensor \u001b[39m=\u001b[39m \u001b[39misinstance\u001b[39m(quant_input, QuantTensor) \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(\n\u001b[1;32m 317\u001b[0m quant_weight, QuantTensor)\n\u001b[1;32m 318\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (compute_output_quant_tensor \u001b[39mor\u001b[39;00m\n\u001b[1;32m 319\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_output_quant_enabled) \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreturn_quant_tensor:\n\u001b[0;32m--> 320\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mQuantLayer is not correctly configured\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 322\u001b[0m \u001b[39mif\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreturn_quant_tensor \u001b[39mor\u001b[39;00m\n\u001b[1;32m 323\u001b[0m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_bias_quant_enabled \u001b[39mand\u001b[39;00m\n\u001b[1;32m 324\u001b[0m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias_quant\u001b[39m.\u001b[39mrequires_input_scale \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias_quant\u001b[39m.\u001b[39mrequires_input_bit_width))):\n\u001b[1;32m 325\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(quant_input, QuantTensor) \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(quant_weight, QuantTensor):\n", "\u001b[0;31mRuntimeError\u001b[0m: QuantLayer is not correctly configured" ] } @@ -827,27 +827,26 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-2.4238e-03, -5.6598e-03, 5.1882e-03],\n", - " [-6.5582e-03, 8.9274e-03, 4.9640e-04],\n", - " [ 9.6283e-03, -1.7466e-03, -4.8311e-03]],\n", + "QuantTensor(value=tensor([[[[-0.0058, 0.0030, 0.0030],\n", + " [-0.0013, -0.0001, 0.0043],\n", + " [-0.0061, 0.0033, -0.0001]],\n", "\n", - " [[ 2.9322e-03, -3.1358e-03, -6.2727e-04],\n", - " [ 2.8722e-06, -3.7981e-03, 1.0973e-02],\n", - " [-4.1031e-03, 6.5909e-03, -4.2369e-03]],\n", + " [[ 0.0013, -0.0008, -0.0015],\n", + " [ 0.0011, 0.0012, -0.0012],\n", + " [-0.0013, -0.0020, 
0.0002]],\n", "\n", - " [[ 4.1967e-03, -7.0733e-03, 1.6456e-03],\n", - " [ 1.8197e-03, -3.1683e-03, 4.8200e-03],\n", - " [-3.2585e-04, 3.1055e-03, 1.9703e-03]]]],\n", - " grad_fn=), scale=tensor([[[[1.7953e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" + " [[-0.0061, 0.0053, -0.0004],\n", + " [ 0.0028, 0.0031, -0.0038],\n", + " [ 0.0026, -0.0048, -0.0044]]]], grad_fn=), scale=tensor([[[[1.8528e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 24, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } @@ -865,26 +864,26 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.2816, -0.5271, -0.1748],\n", - " [-0.4247, -0.1575, 0.0681],\n", - " [ 0.6528, -0.5346, -0.0657]],\n", + "QuantTensor(value=tensor([[[[-0.4300, 0.1726, -0.3396],\n", + " [ 0.0307, -0.0052, -1.1685],\n", + " [-0.3160, 0.1334, -0.4459]],\n", "\n", - " [[ 0.2993, -0.3383, 0.3035],\n", - " [-0.4595, -0.6796, -0.9720],\n", - " [-0.1948, -0.5169, -0.2175]],\n", + " [[ 1.0135, 0.7129, -0.3874],\n", + " [ 0.4858, -0.6205, 0.1563],\n", + " [-0.1631, -0.2198, 0.1444]],\n", "\n", - " [[ 0.5586, 0.0665, -0.5807],\n", - " [ 0.5565, 0.1780, -0.0555],\n", - " [-0.1080, 0.0791, -0.2262]]]], grad_fn=), scale=tensor([[[[4.2009e-05]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" + " [[-1.4600, 0.9106, 0.6328],\n", + " [ 0.6669, -0.1814, -0.0169],\n", + " [ 0.6581, 0.7420, -0.4884]]]], grad_fn=), scale=tensor([[[[2.9050e-05]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 25, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -898,26 +897,26 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.0058, 0.0030, 0.0030],\n", - " [-0.0013, -0.0002, 0.0043],\n", - " [-0.0061, 0.0033, -0.0001]],\n", + "QuantTensor(value=tensor([[[[-0.0015, -0.0035, 0.0003],\n", + " [-0.0054, 0.0047, 0.0055],\n", + " [ 0.0043, 0.0054, -0.0050]],\n", "\n", - " [[ 0.0013, -0.0008, -0.0015],\n", - " [ 0.0011, 0.0012, -0.0012],\n", - " [-0.0013, -0.0020, 0.0002]],\n", + " [[-0.0004, 0.0013, -0.0018],\n", + " [ 0.0055, -0.0073, 0.0023],\n", + " [-0.0053, 0.0009, 0.0032]],\n", "\n", - " [[-0.0061, 0.0053, -0.0004],\n", - " [ 0.0028, 0.0031, -0.0037],\n", - " [ 0.0027, -0.0048, -0.0044]]]], grad_fn=), scale=tensor([[[[1.7370e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" + " [[ 0.0015, -0.0002, -0.0068],\n", + " [ 0.0015, -0.0040, -0.0046],\n", + " [-0.0033, -0.0009, 0.0079]]]], grad_fn=), scale=tensor([[[[1.7377e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 26, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } @@ -937,7 +936,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 28, "metadata": { "tags": [ "raises-exception" @@ -955,7 +954,7 @@ "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in 
\u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compiled_call_impl(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs) \u001b[39m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_impl(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1529\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n", "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_conv.py:198\u001b[0m, in \u001b[0;36mQuantConv2d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m: Union[Tensor, QuantTensor]) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Union[Tensor, QuantTensor]:\n\u001b[0;32m--> 198\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mforward_impl(\u001b[39minput\u001b[39;49m)\n", - "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py:347\u001b[0m, in \u001b[0;36mQuantWeightBiasInputOutputLayer.forward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 344\u001b[0m output_signed \u001b[39m=\u001b[39m quant_input_signed \u001b[39mor\u001b[39;00m quant_weight_signed\n\u001b[1;32m 346\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 347\u001b[0m quant_bias \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbias_quant(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbias, output_scale, output_bit_width)\n\u001b[1;32m 348\u001b[0m quant_bias_value \u001b[39m=\u001b[39m \u001b[39mgetattr\u001b[39m(quant_bias, \u001b[39m'\u001b[39m\u001b[39mvalue\u001b[39m\u001b[39m'\u001b[39m, quant_bias)\n\u001b[1;32m 349\u001b[0m quant_bias_scale \u001b[39m=\u001b[39m \u001b[39mgetattr\u001b[39m(quant_bias, 
\u001b[39m'\u001b[39m\u001b[39mscale\u001b[39m\u001b[39m'\u001b[39m, \u001b[39mNone\u001b[39;00m)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py:334\u001b[0m, in \u001b[0;36mQuantWeightBiasInputOutputLayer.forward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 331\u001b[0m output_signed \u001b[39m=\u001b[39m quant_input\u001b[39m.\u001b[39msigned \u001b[39mor\u001b[39;00m quant_weight\u001b[39m.\u001b[39msigned\n\u001b[1;32m 333\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 334\u001b[0m quant_bias \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbias_quant(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbias, output_scale, output_bit_width)\n\u001b[1;32m 335\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtraining \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcache_inference_quant_bias \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(quant_bias,\n\u001b[1;32m 336\u001b[0m QuantTensor):\n\u001b[1;32m 337\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_cached_bias \u001b[39m=\u001b[39m _CachedIO(quant_bias\u001b[39m.\u001b[39mdetach(), metadata_only\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m)\n", "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compiled_call_impl(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs) \u001b[39m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_impl(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1529\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n", "File 
\u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/parameter_quant.py:206\u001b[0m, in \u001b[0;36mBiasQuantProxyFromInjector.forward\u001b[0;34m(self, x, input_scale, input_bit_width)\u001b[0m\n\u001b[1;32m 204\u001b[0m impl \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mexport_handler \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mexport_mode \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtensor_quant\n\u001b[1;32m 205\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mrequires_input_scale \u001b[39mand\u001b[39;00m input_scale \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 206\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mInput scale required\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 207\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mrequires_input_bit_width \u001b[39mand\u001b[39;00m input_bit_width \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 208\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mInput bit-width required\u001b[39m\u001b[39m\"\u001b[39m)\n", @@ -979,26 +978,26 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[[-0.4360, -0.2674, -0.4194],\n", - " [-0.2412, -0.6360, -0.6838],\n", - " [-0.5227, -0.0199, -0.1445]],\n", + "tensor([[[[-0.6938, 0.0069, 0.1652],\n", + " [-0.4801, -0.8120, 0.5233],\n", + " [ 0.4159, 0.4662, 0.2565]],\n", "\n", - " [[-0.3524, 0.8025, 0.2844],\n", - " [ 0.9945, -0.4782, 0.8064],\n", - " [ 0.5732, 0.1249, 0.3110]],\n", + " [[ 0.3206, -0.5500, -0.5254],\n", + " [ 0.1864, 1.0210, -0.3706],\n", + " [-0.1159, 0.6967, -0.0437]],\n", "\n", - " [[ 0.3223, 0.2530, 0.2753],\n", - " [ 0.5764, -0.2533, -0.0181],\n", - " [-0.4147, 0.2049, -0.9944]]]], grad_fn=)" + " [[-0.6209, -0.5257, -0.6592],\n", + " [ 0.6389, 0.2658, 0.4542],\n", + " [-0.3761, -0.7776, -0.2897]]]], grad_fn=)" ] }, - "execution_count": 28, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -1023,30 +1022,30 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.6912, 0.0086, 0.1628],\n", - " [-0.4786, -0.8073, 0.5224],\n", - " [ 0.4157, 0.4686, 0.2560]],\n", + "QuantTensor(value=tensor([[[[-0.4005, 0.7588, 0.4616],\n", + " [-0.0777, -0.0651, -0.2405],\n", + " [-0.7292, 0.4504, 0.3716]],\n", "\n", - " [[ 0.3170, -0.5486, -0.5216],\n", - " [ 0.1832, 1.0217, -0.3637],\n", - " [-0.1115, 0.6974, -0.0452]],\n", + " [[ 0.4868, -0.4495, -0.1327],\n", + " [ 0.2079, -0.3236, -0.5482],\n", + " [ 0.5471, 0.1503, 0.6813]],\n", "\n", - " [[-0.6168, -0.5241, -0.6593],\n", - " [ 0.6408, 0.2674, 0.4537],\n", - " [-0.3744, -0.7771, -0.2848]]]], grad_fn=), scale=tensor([[[[3.0094e-05]]]], grad_fn=), zero_point=tensor([[[[ 339.3404]],\n", + " [[ 0.4356, -0.2319, 1.0867],\n", + " [ 0.0126, 0.7646, 0.3627],\n", + " [-0.4466, 0.5150, 0.1176]]]], grad_fn=), scale=tensor([[[[2.7130e-05]]]], grad_fn=), zero_point=tensor([[[[ 6313.4204]],\n", "\n", - " [[-4597.1797]],\n", + " [[-2667.2593]],\n", "\n", - " [[-3452.3711]]]], grad_fn=), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" + " [[-5507.9629]]]], grad_fn=), bit_width=tensor(21.), 
signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 29, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } @@ -1061,7 +1060,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 31, "metadata": {}, "outputs": [], "source": [ @@ -1077,26 +1076,26 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[[-0.2327, 0.9267, 0.6294],\n", - " [ 0.0901, 0.1027, -0.0727],\n", - " [-0.5614, 0.6182, 0.5394]],\n", + "tensor([[[[ 0.0650, 0.2496, -1.2857],\n", + " [ 1.0231, 0.0516, 0.7592],\n", + " [ 0.5882, -0.7619, 0.7604]],\n", "\n", - " [[ 0.4179, -0.5184, -0.2016],\n", - " [ 0.1390, -0.3925, -0.6171],\n", - " [ 0.4782, 0.0814, 0.6124]],\n", + " [[-0.6307, 0.1476, 1.0949],\n", + " [-0.1488, 0.0472, 0.0097],\n", + " [-0.2861, 0.0266, -0.2970]],\n", "\n", - " [[ 0.2896, -0.3779, 0.9408],\n", - " [-0.1334, 0.6186, 0.2167],\n", - " [-0.5926, 0.3690, -0.0284]]]], grad_fn=)" + " [[ 0.0580, 1.2994, 0.3841],\n", + " [ 0.2056, 0.0496, -0.7915],\n", + " [ 0.4698, -0.8724, -0.0405]]]], grad_fn=)" ] }, - "execution_count": 31, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } diff --git a/notebooks/03_anatomy_of_a_quantizer.ipynb b/notebooks/03_anatomy_of_a_quantizer.ipynb index 3abd4d4ce..b31bae5e4 100644 --- a/notebooks/03_anatomy_of_a_quantizer.ipynb +++ b/notebooks/03_anatomy_of_a_quantizer.ipynb @@ -248,10 +248,10 @@ { "data": { "text/plain": [ - "(tensor([[-0.1000, 0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, 0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, 0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, 0.1000, 0.1000, -0.1000]], grad_fn=),\n", + "(tensor([[-0.1000, 0.1000, -0.1000, 0.1000],\n", + " [ 0.1000, 0.1000, 0.1000, 0.1000],\n", + " [ 0.1000, -0.1000, 0.1000, 0.1000],\n", + " [-0.1000, -0.1000, -0.1000, 0.1000]], grad_fn=),\n", " tensor(0.1000, grad_fn=),\n", " tensor(0.),\n", " tensor(1.))" @@ -293,10 +293,10 @@ { "data": { "text/plain": [ - "(tensor([[-0.1000, 0.1000, -0.1000, -0.1000],\n", - " [ 0.1000, 0.1000, 0.1000, 0.1000],\n", - " [-0.1000, -0.1000, 0.1000, 0.1000],\n", - " [-0.1000, -0.1000, -0.1000, -0.1000]], grad_fn=),\n", + "(tensor([[ 0.1000, 0.1000, 0.1000, 0.1000],\n", + " [ 0.1000, -0.1000, 0.1000, 0.1000],\n", + " [ 0.1000, 0.1000, 0.1000, -0.1000],\n", + " [-0.1000, 0.1000, 0.1000, -0.1000]], grad_fn=),\n", " tensor(0.1000, grad_fn=),\n", " tensor(0.),\n", " tensor(1.))" @@ -343,10 +343,10 @@ { "data": { "text/plain": [ - "(tensor([[-1., 1., -1., 1.],\n", - " [-1., -1., 1., -1.],\n", - " [-1., -1., 1., 1.],\n", - " [ 1., -1., -1., -1.]], grad_fn=),\n", + "(tensor([[-1., 1., 1., -1.],\n", + " [-1., 1., -1., 1.],\n", + " [-1., 1., 1., 1.],\n", + " [ 1., 1., 1., 1.]], grad_fn=),\n", " tensor(1., grad_fn=),\n", " tensor(0.),\n", " tensor(1.))" @@ -380,10 +380,10 @@ { "data": { "text/plain": [ - "(tensor([[ 0.1000, -0.1000, -0.1000, -0.1000],\n", - " [ 0.1000, -0.1000, -0.1000, -0.1000],\n", - " [ 0.1000, 0.1000, 0.1000, -0.1000],\n", - " [-0.1000, 0.1000, 0.1000, -0.1000]], grad_fn=),\n", + "(tensor([[ 0.1000, -0.1000, -0.1000, 0.1000],\n", + " [ 0.1000, 0.1000, -0.1000, 0.1000],\n", + " [-0.1000, 0.1000, 0.1000, -0.1000],\n", + " [-0.1000, -0.1000, 0.1000, 0.1000]], grad_fn=),\n", " tensor(0.1000, grad_fn=),\n", " tensor(0.),\n", " tensor(1.))" @@ -447,37 +447,20 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "QuantTensor(value=tensor([[[[-0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, 
0.1000, 0.1000],\n", - " [-0.1000, 0.1000, -0.1000]],\n", - "\n", - " [[-0.1000, -0.1000, -0.1000],\n", - " [ 0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, -0.1000, 0.1000]],\n", - "\n", - " [[ 0.1000, -0.1000, 0.1000],\n", - " [-0.1000, -0.1000, 0.1000],\n", - " [-0.1000, -0.1000, 0.1000]]],\n", - "\n", - "\n", - " [[[ 0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, -0.1000],\n", - " [-0.1000, -0.1000, 0.1000]],\n", - "\n", - " [[-0.1000, 0.1000, -0.1000],\n", - " [-0.1000, 0.1000, 0.1000],\n", - " [-0.1000, -0.1000, 0.1000]],\n", - "\n", - " [[-0.1000, 0.1000, -0.1000],\n", - " [-0.1000, -0.1000, -0.1000],\n", - " [ 0.1000, 0.1000, -0.1000]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=None, training_t=tensor(True))" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" + "ename": "TypeError", + "evalue": "'NoneType' object cannot be interpreted as an integer", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/03_anatomy_of_a_quantizer.ipynb Cell 24\u001b[0m line \u001b[0;36m4\n\u001b[1;32m 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mbrevitas\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mnn\u001b[39;00m \u001b[39mimport\u001b[39;00m QuantConv2d\n\u001b[1;32m 3\u001b[0m binary_weight_quant_conv \u001b[39m=\u001b[39m QuantConv2d(\u001b[39m3\u001b[39m, \u001b[39m2\u001b[39m, (\u001b[39m3\u001b[39m,\u001b[39m3\u001b[39m), weight_quant\u001b[39m=\u001b[39mMyBinaryWeightQuantizer)\n\u001b[0;32m----> 4\u001b[0m quant_weight \u001b[39m=\u001b[39m binary_weight_quant_conv\u001b[39m.\u001b[39;49mquant_weight()\n\u001b[1;32m 5\u001b[0m quant_weight\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/mixin/parameter.py:103\u001b[0m, in \u001b[0;36mQuantWeightMixin.quant_weight\u001b[0;34m(self, quant_input, subtensor_slice_list)\u001b[0m\n\u001b[1;32m 100\u001b[0m out \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mweight_quant(\n\u001b[1;32m 101\u001b[0m weights_to_quantize[weight_slice_tuple], input_bit_width, input_is_signed)\n\u001b[1;32m 102\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 103\u001b[0m out \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mweight_quant(weights_to_quantize[weight_slice_tuple])\n\u001b[1;32m 104\u001b[0m \u001b[39mif\u001b[39;00m subtensor_slice_list \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 105\u001b[0m \u001b[39m# Restore the quantizer behaviour to full tensor quantization\u001b[39;00m\n\u001b[1;32m 106\u001b[0m \u001b[39m# The modules to slice should have been cached already at this point\u001b[39;00m\n\u001b[1;32m 107\u001b[0m \u001b[39massert\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_cached_sub_tensor_slice_list_modules \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m, \u001b[39m\"\u001b[39m\u001b[39mMissing cache of modules to slice.\u001b[39m\u001b[39m\"\u001b[39m\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[39mreturn\u001b[39;00m 
\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compiled_call_impl(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs) \u001b[39m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_impl(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1529\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/parameter_quant.py:101\u001b[0m, in \u001b[0;36mWeightQuantProxyFromInjector.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 99\u001b[0m impl \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mexport_handler \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mexport_mode \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtensor_quant\n\u001b[1;32m 100\u001b[0m out, scale, zero_point, bit_width \u001b[39m=\u001b[39m impl(x)\n\u001b[0;32m--> 101\u001b[0m \u001b[39mreturn\u001b[39;00m QuantTensor(out, scale, zero_point, bit_width, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mis_signed, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtraining)\n\u001b[1;32m 102\u001b[0m \u001b[39melse\u001b[39;00m: \u001b[39m# quantization disabled\u001b[39;00m\n\u001b[1;32m 103\u001b[0m \u001b[39mreturn\u001b[39;00m x\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/quant_tensor/__init__.py:71\u001b[0m, in \u001b[0;36mQuantTensor.__new__\u001b[0;34m(cls, value, scale, zero_point, bit_width, signed, training)\u001b[0m\n\u001b[1;32m 69\u001b[0m bit_width \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mtensor(bit_width, dtype\u001b[39m=\u001b[39mtorch\u001b[39m.\u001b[39mfloat)\n\u001b[1;32m 70\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(signed, torch\u001b[39m.\u001b[39mTensor):\n\u001b[0;32m---> 71\u001b[0m signed \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49mtensor(signed, dtype\u001b[39m=\u001b[39;49mtorch\u001b[39m.\u001b[39;49mbool)\n\u001b[1;32m 72\u001b[0m \u001b[39mif\u001b[39;00m 
\u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(training, torch\u001b[39m.\u001b[39mTensor):\n\u001b[1;32m 73\u001b[0m training \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mtensor(training, dtype\u001b[39m=\u001b[39mtorch\u001b[39m.\u001b[39mbool)\n", + "\u001b[0;31mTypeError\u001b[0m: 'NoneType' object cannot be interpreted as an integer" + ] } ], "source": [ @@ -513,39 +496,39 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, 0.1000, -0.1000],\n", - " [-0.1000, 0.1000, -0.1000]],\n", + "QuantTensor(value=tensor([[[[ 0.1000, -0.1000, 0.1000],\n", + " [ 0.1000, -0.1000, 0.1000],\n", + " [-0.1000, -0.1000, 0.1000]],\n", "\n", " [[ 0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, 0.1000],\n", + " [-0.1000, 0.1000, -0.1000],\n", " [-0.1000, 0.1000, -0.1000]],\n", "\n", - " [[-0.1000, -0.1000, -0.1000],\n", - " [ 0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, -0.1000, -0.1000]]],\n", + " [[ 0.1000, 0.1000, 0.1000],\n", + " [-0.1000, 0.1000, 0.1000],\n", + " [ 0.1000, 0.1000, -0.1000]]],\n", "\n", "\n", " [[[ 0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, 0.1000]],\n", - "\n", - " [[ 0.1000, 0.1000, -0.1000],\n", " [ 0.1000, -0.1000, -0.1000],\n", " [-0.1000, 0.1000, -0.1000]],\n", "\n", + " [[ 0.1000, 0.1000, -0.1000],\n", + " [-0.1000, 0.1000, 0.1000],\n", + " [-0.1000, 0.1000, 0.1000]],\n", + "\n", " [[-0.1000, -0.1000, -0.1000],\n", " [ 0.1000, 0.1000, 0.1000],\n", - " [-0.1000, -0.1000, -0.1000]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [-0.1000, 0.1000, -0.1000]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 13, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -561,11 +544,11 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ - "assert signed_quant_weight.is_valid == True" + "assert signed_quant_weight.is_valid" ] }, { diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index b2b10c08a..359dd76b3 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -201,7 +201,8 @@ class TruncQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol) def bit_width(self): zhs = self._zero_hw_sentinel() - empty_imp = QuantTensor(zhs, zhs, zhs, zhs) + # Signed might or might not be defined. 
We just care about retrieving the bitwidth + empty_imp = QuantTensor(zhs, zhs, zhs, zhs, signed=True, training=self.training) bit_width = self.__call__(empty_imp).bit_width return bit_width From 982e07274c5544e4ee0b2384745081b7604e8fe6 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sun, 4 Feb 2024 23:25:20 +0000 Subject: [PATCH 16/32] Fix for notebook --- notebooks/03_anatomy_of_a_quantizer.ipynb | 347 ++++++++++------------ 1 file changed, 164 insertions(+), 183 deletions(-) diff --git a/notebooks/03_anatomy_of_a_quantizer.ipynb b/notebooks/03_anatomy_of_a_quantizer.ipynb index b31bae5e4..21a0b54f4 100644 --- a/notebooks/03_anatomy_of_a_quantizer.ipynb +++ b/notebooks/03_anatomy_of_a_quantizer.ipynb @@ -248,10 +248,10 @@ { "data": { "text/plain": [ - "(tensor([[-0.1000, 0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, 0.1000, 0.1000, 0.1000],\n", + "(tensor([[-0.1000, -0.1000, 0.1000, 0.1000],\n", " [ 0.1000, -0.1000, 0.1000, 0.1000],\n", - " [-0.1000, -0.1000, -0.1000, 0.1000]], grad_fn=),\n", + " [-0.1000, -0.1000, 0.1000, 0.1000],\n", + " [ 0.1000, -0.1000, -0.1000, -0.1000]], grad_fn=),\n", " tensor(0.1000, grad_fn=),\n", " tensor(0.),\n", " tensor(1.))" @@ -293,10 +293,10 @@ { "data": { "text/plain": [ - "(tensor([[ 0.1000, 0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, 0.1000, 0.1000, -0.1000],\n", - " [-0.1000, 0.1000, 0.1000, -0.1000]], grad_fn=),\n", + "(tensor([[ 0.1000, -0.1000, -0.1000, -0.1000],\n", + " [ 0.1000, -0.1000, -0.1000, 0.1000],\n", + " [-0.1000, -0.1000, -0.1000, 0.1000],\n", + " [ 0.1000, -0.1000, 0.1000, 0.1000]], grad_fn=),\n", " tensor(0.1000, grad_fn=),\n", " tensor(0.),\n", " tensor(1.))" @@ -343,10 +343,10 @@ { "data": { "text/plain": [ - "(tensor([[-1., 1., 1., -1.],\n", - " [-1., 1., -1., 1.],\n", - " [-1., 1., 1., 1.],\n", - " [ 1., 1., 1., 1.]], grad_fn=),\n", + "(tensor([[-1., -1., -1., 1.],\n", + " [-1., 1., -1., -1.],\n", + " [-1., -1., -1., -1.],\n", + " [-1., 1., -1., 1.]], grad_fn=),\n", " tensor(1., grad_fn=),\n", " tensor(0.),\n", " tensor(1.))" @@ -380,10 +380,10 @@ { "data": { "text/plain": [ - "(tensor([[ 0.1000, -0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, 0.1000, -0.1000, 0.1000],\n", - " [-0.1000, 0.1000, 0.1000, -0.1000],\n", - " [-0.1000, -0.1000, 0.1000, 0.1000]], grad_fn=),\n", + "(tensor([[-0.1000, 0.1000, -0.1000, -0.1000],\n", + " [-0.1000, -0.1000, 0.1000, -0.1000],\n", + " [-0.1000, -0.1000, -0.1000, 0.1000],\n", + " [-0.1000, 0.1000, -0.1000, 0.1000]], grad_fn=),\n", " tensor(0.1000, grad_fn=),\n", " tensor(0.),\n", " tensor(1.))" @@ -445,46 +445,22 @@ "cell_type": "code", "execution_count": 11, "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "'NoneType' object cannot be interpreted as an integer", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/03_anatomy_of_a_quantizer.ipynb Cell 24\u001b[0m line \u001b[0;36m4\n\u001b[1;32m 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mbrevitas\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mnn\u001b[39;00m \u001b[39mimport\u001b[39;00m QuantConv2d\n\u001b[1;32m 3\u001b[0m binary_weight_quant_conv \u001b[39m=\u001b[39m QuantConv2d(\u001b[39m3\u001b[39m, \u001b[39m2\u001b[39m, (\u001b[39m3\u001b[39m,\u001b[39m3\u001b[39m), weight_quant\u001b[39m=\u001b[39mMyBinaryWeightQuantizer)\n\u001b[0;32m----> 
4\u001b[0m quant_weight \u001b[39m=\u001b[39m binary_weight_quant_conv\u001b[39m.\u001b[39;49mquant_weight()\n\u001b[1;32m 5\u001b[0m quant_weight\n", - "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/mixin/parameter.py:103\u001b[0m, in \u001b[0;36mQuantWeightMixin.quant_weight\u001b[0;34m(self, quant_input, subtensor_slice_list)\u001b[0m\n\u001b[1;32m 100\u001b[0m out \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mweight_quant(\n\u001b[1;32m 101\u001b[0m weights_to_quantize[weight_slice_tuple], input_bit_width, input_is_signed)\n\u001b[1;32m 102\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 103\u001b[0m out \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mweight_quant(weights_to_quantize[weight_slice_tuple])\n\u001b[1;32m 104\u001b[0m \u001b[39mif\u001b[39;00m subtensor_slice_list \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 105\u001b[0m \u001b[39m# Restore the quantizer behaviour to full tensor quantization\u001b[39;00m\n\u001b[1;32m 106\u001b[0m \u001b[39m# The modules to slice should have been cached already at this point\u001b[39;00m\n\u001b[1;32m 107\u001b[0m \u001b[39massert\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_cached_sub_tensor_slice_list_modules \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m, \u001b[39m\"\u001b[39m\u001b[39mMissing cache of modules to slice.\u001b[39m\u001b[39m\"\u001b[39m\n", - "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compiled_call_impl(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs) \u001b[39m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_impl(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", - "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1529\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result 
\u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n", - "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/parameter_quant.py:101\u001b[0m, in \u001b[0;36mWeightQuantProxyFromInjector.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 99\u001b[0m impl \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mexport_handler \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mexport_mode \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtensor_quant\n\u001b[1;32m 100\u001b[0m out, scale, zero_point, bit_width \u001b[39m=\u001b[39m impl(x)\n\u001b[0;32m--> 101\u001b[0m \u001b[39mreturn\u001b[39;00m QuantTensor(out, scale, zero_point, bit_width, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mis_signed, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtraining)\n\u001b[1;32m 102\u001b[0m \u001b[39melse\u001b[39;00m: \u001b[39m# quantization disabled\u001b[39;00m\n\u001b[1;32m 103\u001b[0m \u001b[39mreturn\u001b[39;00m x\n", - "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/quant_tensor/__init__.py:71\u001b[0m, in \u001b[0;36mQuantTensor.__new__\u001b[0;34m(cls, value, scale, zero_point, bit_width, signed, training)\u001b[0m\n\u001b[1;32m 69\u001b[0m bit_width \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mtensor(bit_width, dtype\u001b[39m=\u001b[39mtorch\u001b[39m.\u001b[39mfloat)\n\u001b[1;32m 70\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(signed, torch\u001b[39m.\u001b[39mTensor):\n\u001b[0;32m---> 71\u001b[0m signed \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49mtensor(signed, dtype\u001b[39m=\u001b[39;49mtorch\u001b[39m.\u001b[39;49mbool)\n\u001b[1;32m 72\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(training, torch\u001b[39m.\u001b[39mTensor):\n\u001b[1;32m 73\u001b[0m training \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mtensor(training, dtype\u001b[39m=\u001b[39mtorch\u001b[39m.\u001b[39mbool)\n", - "\u001b[0;31mTypeError\u001b[0m: 'NoneType' object cannot be interpreted as an integer" - ] - } - ], + "outputs": [], "source": [ "from brevitas.nn import QuantConv2d\n", "\n", "binary_weight_quant_conv = QuantConv2d(3, 2, (3,3), weight_quant=MyBinaryWeightQuantizer)\n", - "quant_weight = binary_weight_quant_conv.quant_weight()\n", - "quant_weight" + "try:\n", + " quant_weight = binary_weight_quant_conv.quant_weight()\n", + "except TypeError:\n", + " pass\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Note however how the `QuantTensor` is not properly formed, as the `signed` attribute is `None`. This means that `quant_weight` is not considered valid, as the affine quantization invariant cannot be computed:" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "assert not quant_weight.is_valid" + "Note however that we cannot compute the quantized weight, as the `signed` attribute is `None`." 
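A minimal standalone reproduction of that failure, assuming only PyTorch is installed: `QuantTensor.__new__` coerces `signed` with `torch.tensor(signed, dtype=torch.bool)`, which is exactly the call that rejects `None` in the traceback above.

    import torch

    # QuantTensor.__new__ does torch.tensor(signed, dtype=torch.bool);
    # with signed=None this raises the TypeError shown above.
    try:
        torch.tensor(None, dtype=torch.bool)
    except TypeError as e:
        print(e)  # 'NoneType' object cannot be interpreted as an integer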
] }, { @@ -502,30 +478,30 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, 0.1000],\n", - " [-0.1000, -0.1000, 0.1000]],\n", + "QuantTensor(value=tensor([[[[-0.1000, -0.1000, 0.1000],\n", + " [ 0.1000, 0.1000, 0.1000],\n", + " [ 0.1000, -0.1000, 0.1000]],\n", "\n", - " [[ 0.1000, -0.1000, 0.1000],\n", - " [-0.1000, 0.1000, -0.1000],\n", + " [[-0.1000, 0.1000, -0.1000],\n", + " [-0.1000, -0.1000, -0.1000],\n", " [-0.1000, 0.1000, -0.1000]],\n", "\n", " [[ 0.1000, 0.1000, 0.1000],\n", " [-0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, 0.1000, -0.1000]]],\n", - "\n", + " [-0.1000, -0.1000, 0.1000]]],\n", "\n", - " [[[ 0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, -0.1000],\n", - " [-0.1000, 0.1000, -0.1000]],\n", "\n", - " [[ 0.1000, 0.1000, -0.1000],\n", - " [-0.1000, 0.1000, 0.1000],\n", + " [[[-0.1000, 0.1000, -0.1000],\n", + " [-0.1000, -0.1000, 0.1000],\n", " [-0.1000, 0.1000, 0.1000]],\n", "\n", " [[-0.1000, -0.1000, -0.1000],\n", " [ 0.1000, 0.1000, 0.1000],\n", - " [-0.1000, 0.1000, -0.1000]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [-0.1000, -0.1000, 0.1000]],\n", + "\n", + " [[ 0.1000, -0.1000, -0.1000],\n", + " [-0.1000, 0.1000, -0.1000],\n", + " [ 0.1000, 0.1000, -0.1000]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, "execution_count": 12, @@ -562,39 +538,39 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.1000, -0.1000, 0.1000],\n", - " [-0.1000, -0.1000, -0.1000],\n", - " [ 0.1000, 0.1000, 0.1000]],\n", + "QuantTensor(value=tensor([[[[ 0.1000, 0.1000, 0.1000],\n", + " [ 0.1000, -0.1000, 0.1000],\n", + " [-0.1000, 0.1000, -0.1000]],\n", "\n", - " [[ 0.1000, 0.1000, 0.1000],\n", - " [-0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, -0.1000, -0.1000]],\n", + " [[-0.1000, -0.1000, -0.1000],\n", + " [ 0.1000, -0.1000, 0.1000],\n", + " [-0.1000, 0.1000, -0.1000]],\n", "\n", - " [[ 0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, 0.1000, -0.1000],\n", - " [-0.1000, 0.1000, 0.1000]]],\n", + " [[-0.1000, -0.1000, 0.1000],\n", + " [ 0.1000, -0.1000, 0.1000],\n", + " [ 0.1000, -0.1000, -0.1000]]],\n", "\n", "\n", - " [[[-0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, 0.1000, -0.1000],\n", + " [[[ 0.1000, -0.1000, 0.1000],\n", + " [ 0.1000, 0.1000, 0.1000],\n", " [ 0.1000, -0.1000, 0.1000]],\n", "\n", - " [[ 0.1000, 0.1000, -0.1000],\n", - " [-0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, -0.1000]],\n", - "\n", - " [[ 0.1000, 0.1000, -0.1000],\n", + " [[-0.1000, 0.1000, 0.1000],\n", " [ 0.1000, -0.1000, 0.1000],\n", - " [-0.1000, 0.1000, 0.1000]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [ 0.1000, -0.1000, 0.1000]],\n", + "\n", + " [[-0.1000, -0.1000, -0.1000],\n", + " [-0.1000, 0.1000, -0.1000],\n", + " [-0.1000, 0.1000, -0.1000]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 15, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -624,7 +600,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -638,13 
+614,13 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[-0.1000, -0.1000, 0.1000, -0.1000],\n", - " [-0.1000, 0.1000, -0.1000, -0.1000],\n", - " [-0.1000, 0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, -0.1000, -0.1000, -0.1000]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + "QuantTensor(value=tensor([[ 0.1000, -0.1000, 0.1000, -0.1000],\n", + " [-0.1000, -0.1000, 0.1000, -0.1000],\n", + " [-0.1000, -0.1000, 0.1000, -0.1000],\n", + " [ 0.1000, -0.1000, 0.1000, 0.1000]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 16, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -670,19 +646,19 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[-0.0010, -0.0010, -0.0010, 0.0010],\n", - " [ 0.0010, 0.0010, -0.0010, 0.0010],\n", - " [-0.0010, -0.0010, 0.0010, -0.0010],\n", - " [-0.0010, -0.0010, -0.0010, -0.0010]], grad_fn=), scale=tensor(0.0010, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + "QuantTensor(value=tensor([[-0.0010, 0.0010, 0.0010, -0.0010],\n", + " [-0.0010, 0.0010, -0.0010, -0.0010],\n", + " [ 0.0010, -0.0010, -0.0010, -0.0010],\n", + " [ 0.0010, -0.0010, 0.0010, -0.0010]], grad_fn=), scale=tensor(0.0010, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 17, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -708,7 +684,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -732,7 +708,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 18, "metadata": { "scrolled": true }, @@ -740,33 +716,33 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.1918, 0.1918, 0.1918],\n", - " [ 0.1918, 0.1918, 0.1918],\n", - " [-0.1918, -0.1918, 0.1918]],\n", + "QuantTensor(value=tensor([[[[ 0.1904, -0.1904, -0.1904],\n", + " [-0.1904, 0.1904, -0.1904],\n", + " [-0.1904, 0.1904, 0.1904]],\n", "\n", - " [[-0.1918, -0.1918, 0.1918],\n", - " [-0.1918, 0.1918, -0.1918],\n", - " [ 0.1918, 0.1918, 0.1918]],\n", + " [[-0.1904, 0.1904, -0.1904],\n", + " [ 0.1904, -0.1904, -0.1904],\n", + " [ 0.1904, 0.1904, -0.1904]],\n", "\n", - " [[-0.1918, 0.1918, 0.1918],\n", - " [ 0.1918, -0.1918, -0.1918],\n", - " [ 0.1918, 0.1918, 0.1918]]],\n", + " [[-0.1904, -0.1904, 0.1904],\n", + " [-0.1904, -0.1904, -0.1904],\n", + " [ 0.1904, 0.1904, -0.1904]]],\n", "\n", "\n", - " [[[ 0.1918, -0.1918, 0.1918],\n", - " [-0.1918, -0.1918, 0.1918],\n", - " [ 0.1918, 0.1918, 0.1918]],\n", + " [[[-0.1904, 0.1904, 0.1904],\n", + " [ 0.1904, -0.1904, -0.1904],\n", + " [ 0.1904, 0.1904, 0.1904]],\n", "\n", - " [[ 0.1918, 0.1918, 0.1918],\n", - " [ 0.1918, -0.1918, -0.1918],\n", - " [ 0.1918, 0.1918, 0.1918]],\n", + " [[ 0.1904, -0.1904, 0.1904],\n", + " [ 0.1904, 0.1904, 0.1904],\n", + " [ 0.1904, -0.1904, -0.1904]],\n", "\n", - " [[-0.1918, 0.1918, -0.1918],\n", - " [ 0.1918, -0.1918, 0.1918],\n", - " [ 0.1918, -0.1918, 0.1918]]]], grad_fn=), scale=tensor(0.1918, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [[-0.1904, 0.1904, 0.1904],\n", + " 
[-0.1904, -0.1904, -0.1904],\n", + " [-0.1904, -0.1904, 0.1904]]]], grad_fn=), scale=tensor(0.1904, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 19, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -785,7 +761,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -794,7 +770,7 @@ "True" ] }, - "execution_count": 20, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -812,16 +788,16 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor(0.1860, grad_fn=)" + "tensor(0.1876, grad_fn=)" ] }, - "execution_count": 21, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -856,7 +832,7 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/03_anatomy_of_a_quantizer.ipynb Cell 46\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m param_from_max_quant_conv\u001b[39m.\u001b[39;49mload_state_dict(float_conv\u001b[39m.\u001b[39;49mstate_dict())\n", + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/03_anatomy_of_a_quantizer.ipynb Cell 45\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m param_from_max_quant_conv\u001b[39m.\u001b[39;49mload_state_dict(float_conv\u001b[39m.\u001b[39;49mstate_dict())\n", "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:2152\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[0;34m(self, state_dict, strict, assign)\u001b[0m\n\u001b[1;32m 2147\u001b[0m error_msgs\u001b[39m.\u001b[39minsert(\n\u001b[1;32m 2148\u001b[0m \u001b[39m0\u001b[39m, \u001b[39m'\u001b[39m\u001b[39mMissing key(s) in state_dict: \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m. \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 2149\u001b[0m \u001b[39m'\u001b[39m\u001b[39m, \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mjoin(\u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00mk\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m'\u001b[39m \u001b[39mfor\u001b[39;00m k \u001b[39min\u001b[39;00m missing_keys)))\n\u001b[1;32m 2151\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(error_msgs) \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[0;32m-> 2152\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m'\u001b[39m\u001b[39mError(s) in loading state_dict for \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 2153\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__class__\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mjoin(error_msgs)))\n\u001b[1;32m 2154\u001b[0m \u001b[39mreturn\u001b[39;00m _IncompatibleKeys(missing_keys, unexpected_keys)\n", "\u001b[0;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for QuantConv2d:\n\tMissing key(s) in state_dict: \"weight_quant.tensor_quant.scaling_impl.value\". 
" ] @@ -914,30 +890,30 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.1860, 0.1860, 0.1860],\n", - " [-0.1860, 0.1860, -0.1860],\n", - " [-0.1860, 0.1860, -0.1860]],\n", + "QuantTensor(value=tensor([[[[ 0.1876, 0.1876, -0.1876],\n", + " [ 0.1876, -0.1876, 0.1876],\n", + " [-0.1876, 0.1876, -0.1876]],\n", "\n", - " [[ 0.1860, -0.1860, 0.1860],\n", - " [-0.1860, 0.1860, 0.1860],\n", - " [ 0.1860, -0.1860, -0.1860]],\n", + " [[-0.1876, 0.1876, 0.1876],\n", + " [-0.1876, 0.1876, -0.1876],\n", + " [ 0.1876, -0.1876, -0.1876]],\n", "\n", - " [[-0.1860, -0.1860, -0.1860],\n", - " [-0.1860, 0.1860, 0.1860],\n", - " [ 0.1860, 0.1860, -0.1860]]],\n", + " [[-0.1876, -0.1876, -0.1876],\n", + " [ 0.1876, 0.1876, 0.1876],\n", + " [ 0.1876, -0.1876, -0.1876]]],\n", "\n", "\n", - " [[[ 0.1860, -0.1860, 0.1860],\n", - " [-0.1860, -0.1860, 0.1860],\n", - " [-0.1860, 0.1860, -0.1860]],\n", + " [[[-0.1876, -0.1876, -0.1876],\n", + " [-0.1876, 0.1876, -0.1876],\n", + " [-0.1876, -0.1876, -0.1876]],\n", "\n", - " [[ 0.1860, -0.1860, 0.1860],\n", - " [ 0.1860, -0.1860, -0.1860],\n", - " [ 0.1860, 0.1860, 0.1860]],\n", + " [[-0.1876, -0.1876, -0.1876],\n", + " [-0.1876, 0.1876, 0.1876],\n", + " [-0.1876, -0.1876, -0.1876]],\n", "\n", - " [[ 0.1860, -0.1860, -0.1860],\n", - " [-0.1860, -0.1860, -0.1860],\n", - " [-0.1860, 0.1860, 0.1860]]]], grad_fn=), scale=tensor(0.1860, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [[ 0.1876, 0.1876, -0.1876],\n", + " [-0.1876, -0.1876, 0.1876],\n", + " [ 0.1876, 0.1876, 0.1876]]]], grad_fn=), scale=tensor(0.1876, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, "execution_count": 24, @@ -1224,33 +1200,33 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.1876, -0.1876, 0.1876],\n", - " [ 0.1876, 0.1876, -0.1876],\n", - " [ 0.1876, -0.1876, -0.1876]],\n", + "QuantTensor(value=tensor([[[[-0.1903, 0.1903, -0.1903],\n", + " [ 0.1903, 0.1903, -0.1903],\n", + " [-0.1903, -0.1903, -0.1903]],\n", "\n", - " [[-0.1876, -0.1876, -0.1876],\n", - " [ 0.1876, -0.1876, 0.1876],\n", - " [-0.1876, 0.1876, -0.1876]],\n", + " [[ 0.1903, -0.1903, -0.1903],\n", + " [ 0.1903, 0.1903, -0.1903],\n", + " [ 0.1903, -0.1903, 0.1903]],\n", "\n", - " [[ 0.1876, 0.1876, -0.1876],\n", - " [-0.1876, -0.1876, -0.1876],\n", - " [-0.1876, -0.1876, 0.1876]]],\n", + " [[-0.1903, -0.1903, -0.1903],\n", + " [-0.1903, -0.1903, 0.1903],\n", + " [-0.1903, 0.1903, -0.1903]]],\n", "\n", "\n", - " [[[ 0.1867, 0.1867, -0.1867],\n", - " [-0.1867, 0.1867, 0.1867],\n", - " [-0.1867, -0.1867, 0.1867]],\n", + " [[[ 0.1870, 0.1870, -0.1870],\n", + " [ 0.1870, 0.1870, -0.1870],\n", + " [-0.1870, 0.1870, -0.1870]],\n", "\n", - " [[-0.1867, -0.1867, -0.1867],\n", - " [-0.1867, 0.1867, 0.1867],\n", - " [ 0.1867, 0.1867, -0.1867]],\n", + " [[-0.1870, 0.1870, 0.1870],\n", + " [ 0.1870, 0.1870, 0.1870],\n", + " [ 0.1870, 0.1870, 0.1870]],\n", "\n", - " [[-0.1867, -0.1867, 0.1867],\n", - " [ 0.1867, -0.1867, 0.1867],\n", - " [ 0.1867, 0.1867, -0.1867]]]], grad_fn=), scale=tensor([[[[0.1876]]],\n", + " [[-0.1870, -0.1870, -0.1870],\n", + " [ 0.1870, -0.1870, -0.1870],\n", + " [-0.1870, -0.1870, 0.1870]]]], grad_fn=), scale=tensor([[[[0.1903]]],\n", "\n", "\n", - " [[[0.1867]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [[[0.1870]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), 
training_t=tensor(True))" ] }, "execution_count": 33, @@ -1282,33 +1258,33 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.1859, 0.1859, 0.1859],\n", - " [-0.1859, 0.1859, -0.1859],\n", - " [-0.1859, 0.1859, -0.1859]],\n", + "QuantTensor(value=tensor([[[[ 0.1873, 0.1873, -0.1873],\n", + " [ 0.1873, -0.1873, 0.1873],\n", + " [-0.1873, 0.1873, -0.1873]],\n", "\n", - " [[ 0.1859, -0.1859, 0.1859],\n", - " [-0.1859, 0.1859, 0.1859],\n", - " [ 0.1859, -0.1859, -0.1859]],\n", + " [[-0.1873, 0.1873, 0.1873],\n", + " [-0.1873, 0.1873, -0.1873],\n", + " [ 0.1873, -0.1873, -0.1873]],\n", "\n", - " [[-0.1859, -0.1859, -0.1859],\n", - " [-0.1859, 0.1859, 0.1859],\n", - " [ 0.1859, 0.1859, -0.1859]]],\n", + " [[-0.1873, -0.1873, -0.1873],\n", + " [ 0.1873, 0.1873, 0.1873],\n", + " [ 0.1873, -0.1873, -0.1873]]],\n", "\n", "\n", - " [[[ 0.1860, -0.1860, 0.1860],\n", - " [-0.1860, -0.1860, 0.1860],\n", - " [-0.1860, 0.1860, -0.1860]],\n", + " [[[-0.1876, -0.1876, -0.1876],\n", + " [-0.1876, 0.1876, -0.1876],\n", + " [-0.1876, -0.1876, -0.1876]],\n", "\n", - " [[ 0.1860, -0.1860, 0.1860],\n", - " [ 0.1860, -0.1860, -0.1860],\n", - " [ 0.1860, 0.1860, 0.1860]],\n", + " [[-0.1876, -0.1876, -0.1876],\n", + " [-0.1876, 0.1876, 0.1876],\n", + " [-0.1876, -0.1876, -0.1876]],\n", "\n", - " [[ 0.1860, -0.1860, -0.1860],\n", - " [-0.1860, -0.1860, -0.1860],\n", - " [-0.1860, 0.1860, 0.1860]]]], grad_fn=), scale=tensor([[[[0.1859]]],\n", + " [[ 0.1876, 0.1876, -0.1876],\n", + " [-0.1876, -0.1876, 0.1876],\n", + " [ 0.1876, 0.1876, 0.1876]]]], grad_fn=), scale=tensor([[[[0.1873]]],\n", "\n", "\n", - " [[[0.1860]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [[[0.1876]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, "execution_count": 34, @@ -1338,10 +1314,10 @@ { "data": { "text/plain": [ - "tensor([[-0.0100, 0.0100, 0.0100, -0.0100],\n", - " [-0.0100, 0.0100, 0.0100, -0.0100],\n", - " [-0.0100, 0.0100, -0.0100, -0.0100],\n", - " [-0.0100, -0.0100, -0.0100, 0.0100]], grad_fn=)" + "tensor([[ 0.0100, 0.0100, -0.0100, -0.0100],\n", + " [ 0.0100, -0.0100, 0.0100, -0.0100],\n", + " [ 0.0100, -0.0100, 0.0100, -0.0100],\n", + " [-0.0100, -0.0100, 0.0100, -0.0100]], grad_fn=)" ] }, "execution_count": 35, @@ -1381,11 +1357,11 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mDependencyError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/03_anatomy_of_a_quantizer.ipynb Cell 76\u001b[0m line \u001b[0;36m3\n\u001b[1;32m 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mbrevitas\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mnn\u001b[39;00m \u001b[39mimport\u001b[39;00m QuantIdentity\n\u001b[0;32m----> 3\u001b[0m quant_identity \u001b[39m=\u001b[39m QuantIdentity(\n\u001b[1;32m 4\u001b[0m act_quant\u001b[39m=\u001b[39;49mAdvancedActQuantizer, is_clamped\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m, scaling_per_output_channel\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n", + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/03_anatomy_of_a_quantizer.ipynb Cell 75\u001b[0m line \u001b[0;36m3\n\u001b[1;32m 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mbrevitas\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mnn\u001b[39;00m \u001b[39mimport\u001b[39;00m QuantIdentity\n\u001b[0;32m----> 3\u001b[0m quant_identity 
\u001b[39m=\u001b[39m QuantIdentity(\n\u001b[1;32m 4\u001b[0m act_quant\u001b[39m=\u001b[39;49mAdvancedActQuantizer, is_clamped\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m, scaling_per_output_channel\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n", "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_activation.py:113\u001b[0m, in \u001b[0;36mQuantIdentity.__init__\u001b[0;34m(self, act_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[1;32m 108\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\n\u001b[1;32m 109\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 110\u001b[0m act_quant: Optional[ActQuantType] \u001b[39m=\u001b[39m Int8ActPerTensorFloat,\n\u001b[1;32m 111\u001b[0m return_quant_tensor: \u001b[39mbool\u001b[39m \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m,\n\u001b[1;32m 112\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[0;32m--> 113\u001b[0m QuantNLAL\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\n\u001b[1;32m 114\u001b[0m \u001b[39mself\u001b[39;49m,\n\u001b[1;32m 115\u001b[0m input_quant\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m 116\u001b[0m act_impl\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m 117\u001b[0m passthrough_act\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m 118\u001b[0m act_quant\u001b[39m=\u001b[39;49mact_quant,\n\u001b[1;32m 119\u001b[0m return_quant_tensor\u001b[39m=\u001b[39;49mreturn_quant_tensor,\n\u001b[1;32m 120\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", - "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py:40\u001b[0m, in \u001b[0;36mQuantNonLinearActLayer.__init__\u001b[0;34m(self, act_impl, passthrough_act, input_quant, act_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[1;32m 38\u001b[0m QuantLayerMixin\u001b[39m.\u001b[39m\u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, return_quant_tensor)\n\u001b[1;32m 39\u001b[0m QuantInputMixin\u001b[39m.\u001b[39m\u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, input_quant, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[0;32m---> 40\u001b[0m QuantNonLinearActMixin\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\u001b[39mself\u001b[39;49m, act_impl, passthrough_act, act_quant, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py:37\u001b[0m, in \u001b[0;36mQuantNonLinearActLayer.__init__\u001b[0;34m(self, act_impl, passthrough_act, input_quant, act_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[1;32m 35\u001b[0m QuantLayerMixin\u001b[39m.\u001b[39m\u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, return_quant_tensor)\n\u001b[1;32m 36\u001b[0m QuantInputMixin\u001b[39m.\u001b[39m\u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, input_quant, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[0;32m---> 37\u001b[0m QuantNonLinearActMixin\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\u001b[39mself\u001b[39;49m, act_impl, passthrough_act, act_quant, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/mixin/act.py:118\u001b[0m, in \u001b[0;36mQuantNonLinearActMixin.__init__\u001b[0;34m(self, act_impl, passthrough_act, act_quant, act_proxy_prefix, act_kwargs_prefix, **kwargs)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\n\u001b[1;32m 108\u001b[0m 
\u001b[39mself\u001b[39m,\n\u001b[1;32m 109\u001b[0m act_impl: Optional[Type[Module]],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 113\u001b[0m act_kwargs_prefix\u001b[39m=\u001b[39m\u001b[39m'\u001b[39m\u001b[39m'\u001b[39m,\n\u001b[1;32m 114\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m prefixed_kwargs \u001b[39m=\u001b[39m {\n\u001b[1;32m 116\u001b[0m act_kwargs_prefix \u001b[39m+\u001b[39m \u001b[39m'\u001b[39m\u001b[39mact_impl\u001b[39m\u001b[39m'\u001b[39m: act_impl,\n\u001b[1;32m 117\u001b[0m act_kwargs_prefix \u001b[39m+\u001b[39m \u001b[39m'\u001b[39m\u001b[39mpassthrough_act\u001b[39m\u001b[39m'\u001b[39m: passthrough_act}\n\u001b[0;32m--> 118\u001b[0m QuantProxyMixin\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\n\u001b[1;32m 119\u001b[0m \u001b[39mself\u001b[39;49m,\n\u001b[1;32m 120\u001b[0m quant\u001b[39m=\u001b[39;49mact_quant,\n\u001b[1;32m 121\u001b[0m proxy_prefix\u001b[39m=\u001b[39;49mact_proxy_prefix,\n\u001b[1;32m 122\u001b[0m kwargs_prefix\u001b[39m=\u001b[39;49mact_kwargs_prefix,\n\u001b[1;32m 123\u001b[0m proxy_protocol\u001b[39m=\u001b[39;49mActQuantProxyProtocol,\n\u001b[1;32m 124\u001b[0m none_quant_injector\u001b[39m=\u001b[39;49mNoneActQuant,\n\u001b[1;32m 125\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mprefixed_kwargs,\n\u001b[1;32m 126\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", - "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/mixin/base.py:71\u001b[0m, in \u001b[0;36mQuantProxyMixin.__init__\u001b[0;34m(self, quant, proxy_protocol, none_quant_injector, proxy_prefix, kwargs_prefix, **kwargs)\u001b[0m\n\u001b[1;32m 69\u001b[0m quant_injector \u001b[39m=\u001b[39m quant\n\u001b[1;32m 70\u001b[0m quant_injector \u001b[39m=\u001b[39m quant_injector\u001b[39m.\u001b[39mlet(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39mfilter_kwargs(kwargs_prefix, kwargs))\n\u001b[0;32m---> 71\u001b[0m quant \u001b[39m=\u001b[39m quant_injector\u001b[39m.\u001b[39;49mproxy_class(\u001b[39mself\u001b[39;49m, quant_injector)\n\u001b[1;32m 72\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 73\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(quant, proxy_protocol):\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/mixin/base.py:70\u001b[0m, in \u001b[0;36mQuantProxyMixin.__init__\u001b[0;34m(self, quant, proxy_protocol, none_quant_injector, proxy_prefix, kwargs_prefix, **kwargs)\u001b[0m\n\u001b[1;32m 68\u001b[0m quant_injector \u001b[39m=\u001b[39m quant\n\u001b[1;32m 69\u001b[0m quant_injector \u001b[39m=\u001b[39m quant_injector\u001b[39m.\u001b[39mlet(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39mfilter_kwargs(kwargs_prefix, kwargs))\n\u001b[0;32m---> 70\u001b[0m quant \u001b[39m=\u001b[39m quant_injector\u001b[39m.\u001b[39;49mproxy_class(\u001b[39mself\u001b[39;49m, quant_injector)\n\u001b[1;32m 71\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 72\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(quant, proxy_protocol):\n", "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/runtime_quant.py:89\u001b[0m, in \u001b[0;36mActQuantProxyFromInjector.__init__\u001b[0;34m(self, quant_layer, quant_injector)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, quant_layer, quant_injector):\n\u001b[0;32m---> 89\u001b[0m 
QuantProxyFromInjector\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\u001b[39mself\u001b[39;49m, quant_layer, quant_injector)\n\u001b[1;32m 90\u001b[0m ActQuantProxyProtocol\u001b[39m.\u001b[39m\u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m)\n\u001b[1;32m 91\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_passthrough_act \u001b[39m=\u001b[39m _is_passthrough_act(quant_injector)\n", "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/quant_proxy.py:82\u001b[0m, in \u001b[0;36mQuantProxyFromInjector.__init__\u001b[0;34m(self, quant_layer, quant_injector)\u001b[0m\n\u001b[1;32m 80\u001b[0m \u001b[39m# Use a normal list and not a ModuleList since this is a pointer to parent modules\u001b[39;00m\n\u001b[1;32m 81\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtracked_module_list \u001b[39m=\u001b[39m []\n\u001b[0;32m---> 82\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49madd_tracked_module(quant_layer)\n\u001b[1;32m 83\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdisable_quant \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m\n", "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/quant_proxy.py:120\u001b[0m, in \u001b[0;36mQuantProxyFromInjector.add_tracked_module\u001b[0;34m(self, module)\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtracked_module_list\u001b[39m.\u001b[39mappend(module)\n\u001b[1;32m 119\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mupdate_tracked_modules()\n\u001b[0;32m--> 120\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49minit_tensor_quant()\n\u001b[1;32m 121\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 122\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mTrying to add None as a parent module.\u001b[39m\u001b[39m\"\u001b[39m)\n", @@ -1419,10 +1395,10 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[-0.0100, 0.0100, 0.0100, 0.0100],\n", - " [-0.0100, 0.0100, -0.0100, 0.0100],\n", - " [-0.0100, -0.0100, -0.0100, -0.0100],\n", - " [ 0.0100, 0.0100, 0.0100, -0.0100]], grad_fn=), scale=tensor([[0.0100],\n", + "QuantTensor(value=tensor([[-0.0100, 0.0100, 0.0100, -0.0100],\n", + " [ 0.0100, -0.0100, -0.0100, -0.0100],\n", + " [ 0.0100, -0.0100, 0.0100, -0.0100],\n", + " [ 0.0100, -0.0100, -0.0100, 0.0100]], grad_fn=), scale=tensor([[0.0100],\n", " [0.0100],\n", " [0.0100],\n", " [0.0100]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" @@ -1446,6 +1422,11 @@ "source": [ "We have seen how powerful dependency injection is. In a way, it's even too expressive. For users that are not interesting in building completely custom quantizers, it can be hard to make sense of how the various components available under `brevitas.core` can be assembled together according to best practices." 
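As a rough sketch of the kind of assembly this notebook builds up to (import paths assumed, not verified here), a binary weight quantizer is just an injector that names its components:

    from brevitas.core.quant import BinaryQuant          # assumed export path
    from brevitas.core.scaling import ParameterScaling   # assumed export path
    from brevitas.inject import ExtendedInjector
    from brevitas.proxy import WeightQuantProxyFromInjector

    class MyBinaryWeightQuantizer(ExtendedInjector):
        tensor_quant = BinaryQuant
        proxy_class = WeightQuantProxyFromInjector
        scaling_impl = ParameterScaling
        scaling_init = 0.1
        signed = True  # supplying signed explicitly avoids the None failure shown earlier

Each attribute is resolved by dependency injection when the proxy builds its `tensor_quant`, which is why a missing attribute surfaces only later, as the `DependencyError` above shows.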
] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] } ], "metadata": { From c6f76cfa2548c9084bc0e8720ca7bcdb4dae37a1 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 5 Feb 2024 05:27:34 +0000 Subject: [PATCH 17/32] Review --- src/brevitas/core/stats/stats_op.py | 4 ++-- src/brevitas/nn/mixin/base.py | 8 ++------ src/brevitas/nn/quant_layer.py | 10 ++++------ src/brevitas/quant_tensor/__init__.py | 4 ---- tests/brevitas/nn/test_linear.py | 4 +++- tests/brevitas/nn/test_nn_quantizers.py | 6 +++--- 6 files changed, 14 insertions(+), 22 deletions(-) diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py index dee6011d5..fac729326 100644 --- a/src/brevitas/core/stats/stats_op.py +++ b/src/brevitas/core/stats/stats_op.py @@ -12,6 +12,7 @@ from brevitas import config from brevitas.core.utils import StatelessBuffer from brevitas.function.ops import max_int +from brevitas.quant_tensor import _unpack_quant_tensor # Use custom implementation of kthvalue as work around to (b)float16 kernel limitations from brevitas.utils.torch_utils import kthvalue @@ -478,8 +479,7 @@ def evaluate_loss(self, x, candidate): # Set to local_loss_mode before calling the proxy self.set_local_loss_mode(True) quant_value = self.proxy_forward(x) - if isinstance(quant_value, tuple): - quant_value = quant_value.value + quant_value = _unpack_quant_tensor(quant_value) loss = self.mse_loss_fn(x, quant_value) self.set_local_loss_mode(False) return loss diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index 7ef184de5..2603e01f0 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -18,6 +18,7 @@ from brevitas.inject import ExtendedInjector from brevitas.inject import Injector from brevitas.nn.utils import compute_channel_view_shape +from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import QuantTensor from .utils import filter_kwargs @@ -181,10 +182,7 @@ def pack_output(self, quant_output: QuantTensor): if self.return_quant_tensor: return quant_output else: - if isinstance(quant_output, QuantTensor): - return quant_output.value - else: - return quant_output + return _unpack_quant_tensor(quant_output) class QuantRecurrentLayerMixin(ExportMixin): @@ -268,8 +266,6 @@ def maybe_quantize_input(self, inp): quant_input = inp if not self.quantize_output_only: quant_input = self.io_quant(quant_input) - # elif not isinstance(inp, QuantTensor): - # quant_input = QuantTensor(quant_input) return quant_input def maybe_quantize_state(self, inp, state, quant): diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index 206ab9584..42d74089d 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -136,8 +136,7 @@ def forward(self, input: Union[Tensor, QuantTensor]): quant_input = self.input_quant(input) # shortcut execution through the export impl during export if self.export_mode: - # quant_input_value = getattr(quant_input, 'value', quant_input) - out = self.export_handler(quant_input) + out = self.export_handler(_unpack_quant_tensor(quant_input)) self._set_global_is_quant_layer(False) return out out = self.act_quant(quant_input) @@ -298,15 +297,14 @@ def merge_bn_in(self, bn): def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: output_scale = None output_bit_width = None - output_signed = None output_zero_point = None + output_signed = None inp = self.unpack_input(inp) # shortcut execution through the export impl during 
export if self.export_mode: - inp_value = getattr(inp, 'value', inp) - out = self.export_handler(inp_value) + out = self.export_handler(_unpack_quant_tensor(inp)) self._set_global_is_quant_layer(False) return out @@ -369,7 +367,7 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe elif self.return_quant_tensor and output_zero_point is None: output_zero_point = torch.zeros(1).type_as(output_tensor) - if not self.return_quant_tensor or not compute_output_quant_tensor: + if not compute_output_quant_tensor: quant_output = output_tensor else: quant_output = QuantTensor( diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index 865211b8f..08bc5e2dc 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -20,10 +20,6 @@ BFLOAT16_IS_VALID_ATOL = 0.5 -def _get_dequantize_tensor(input): - return input.value if isinstance(input, QuantTensor) else input - - class QuantTensorBase(NamedTuple): value: Tensor scale: Optional[Tensor] diff --git a/tests/brevitas/nn/test_linear.py b/tests/brevitas/nn/test_linear.py index 457b66e20..b9690ce63 100644 --- a/tests/brevitas/nn/test_linear.py +++ b/tests/brevitas/nn/test_linear.py @@ -56,5 +56,7 @@ def test_forward_bias_int(self): torch.rand(size=(3, INPUT_FEATURES)), torch.tensor(1.0), torch.tensor(0.0), - torch.tensor(3)) + torch.tensor(3), + signed=True, + training=False) assert mod(x) is not None diff --git a/tests/brevitas/nn/test_nn_quantizers.py b/tests/brevitas/nn/test_nn_quantizers.py index 2b74b0c10..a10c7160e 100644 --- a/tests/brevitas/nn/test_nn_quantizers.py +++ b/tests/brevitas/nn/test_nn_quantizers.py @@ -123,12 +123,12 @@ def test_quant_lstm_rnn_full(model_input, current_cases): assert isinstance(h, torch.Tensor) if c is not None: - if not return_quant_tensor: - assert isinstance(c, torch.Tensor) - else: + if return_quant_tensor: assert isinstance(c, QuantTensor) assert c.scale is not None assert c.bit_width is not None + else: + assert isinstance(c, torch.Tensor) @pytest_cases.parametrize_with_cases('model_input', cases=[case_quant_lstm, case_quant_rnn]) From 78209341e8fa66b19d93473ce0afbf7245987976 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 5 Feb 2024 06:28:23 +0000 Subject: [PATCH 18/32] Review 2 --- src/brevitas/nn/mixin/base.py | 11 +++++++---- src/brevitas/nn/quant_layer.py | 2 +- src/brevitas/nn/quant_rnn.py | 5 ++--- tests/brevitas/nn/nn_quantizers_fixture.py | 11 +++-------- tests/brevitas/nn/test_nn_quantizers.py | 2 -- 5 files changed, 13 insertions(+), 18 deletions(-) diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index 2603e01f0..e4743207a 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -247,10 +247,9 @@ def gate_params_fwd(gate, quant_input): if isinstance(quant_input, QuantTensor): acc_bit_width = None # TODO if isinstance(quant_input, QuantTensor) and isinstance(quant_weight_ih, QuantTensor): - if quant_input.scale is not None and quant_weight_ih.scale is not None: - acc_scale_shape = compute_channel_view_shape(quant_input.value, channel_dim=1) - acc_scale = quant_weight_ih.scale.view(acc_scale_shape) - acc_scale = acc_scale * quant_input.scale.view(acc_scale_shape) + acc_scale_shape = compute_channel_view_shape(quant_input.value, channel_dim=1) + acc_scale = quant_weight_ih.scale.view(acc_scale_shape) + acc_scale = acc_scale * quant_input.scale.view(acc_scale_shape) quant_bias = gate.bias_quant(gate.bias, acc_scale, acc_bit_width) return 
quant_weight_ih, quant_weight_hh, quant_bias

@@ -279,6 +278,8 @@ def maybe_quantize_state(self, inp, state, quant):

     def pack_quant_outputs(self, quant_outputs):
         # In export mode, quant_outputs has the shape of the output concatenated value
+        # Even though we check that return_quant_tensor can be enabled only with io_quant != None,
+        # inner layers in a deep network override it, so we check again.
         if self.export_mode:
             if self.return_quant_tensor and self.io_quant.is_quant_enabled:
                 return QuantTensor(
@@ -308,6 +309,8 @@ def pack_quant_outputs(self, quant_outputs):
         return torch.cat(outputs, dim=seq_dim)

     def pack_quant_state(self, quant_state, quant):
+        # Even though we check that return_quant_tensor can be enabled only with quant != None,
+        # inner layers in a deep network override it, so we check again.
         if self.export_mode:
             if self.return_quant_tensor and quant.is_quant_enabled:
                 quant_state = QuantTensor(
diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py
index 42d74089d..6fb8f69e0 100644
--- a/src/brevitas/nn/quant_layer.py
+++ b/src/brevitas/nn/quant_layer.py
@@ -367,7 +367,7 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
         elif self.return_quant_tensor and output_zero_point is None:
             output_zero_point = torch.zeros(1).type_as(output_tensor)

-        if not compute_output_quant_tensor:
+        if not self.return_quant_tensor or not compute_output_quant_tensor:
             quant_output = output_tensor
         else:
             quant_output = QuantTensor(
diff --git a/src/brevitas/nn/quant_rnn.py b/src/brevitas/nn/quant_rnn.py
index 745aef3e4..c27bf199b 100644
--- a/src/brevitas/nn/quant_rnn.py
+++ b/src/brevitas/nn/quant_rnn.py
@@ -970,9 +970,8 @@ def __init__(
             **kwargs)
         if cat_output_cell_states and cell_state_quant is not None and not shared_cell_state_quant:
             raise RuntimeError("Concatenating cell states requires shared cell quantizers.")
-        if return_quant_tensor and (io_quant is None or cell_state_quant is None):
-            raise RuntimeError(
-                "To return a valid QuantTensor, specify a io_quant and cell_state_quant")
+        if return_quant_tensor and cell_state_quant is None:
+            raise RuntimeError("return_quant_tensor=True requires cell_state_quant != None.")
         self.cat_output_cell_states = cat_output_cell_states

     def forward(self, inp, hx=None, cx=None):
diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py
index 783f906b7..98a4a0d7b 100644
--- a/tests/brevitas/nn/nn_quantizers_fixture.py
+++ b/tests/brevitas/nn/nn_quantizers_fixture.py
@@ -384,6 +384,8 @@ def case_quant_lstm_full(

     if return_quant_tensor and io_quantizer is None:
         pytest.skip("return_quant_tensor cannot be True if no io_quantizer is specified")
+    if return_quant_tensor and signed_act_quantizer is None:
+        pytest.skip("return_quant_tensor cannot be True if no cell_state_quant is specified")

     class Model(nn.Module):

@@ -410,14 +412,7 @@ def forward(self, x):
             return self.lstm(x)

     torch.random.manual_seed(SEED)
-    if return_quant_tensor and (io_quantizer is None or signed_act_quantizer is None):
-        with pytest.raises(
-                RuntimeError,
-                match="To return a valid QuantTensor, specify a io_quant and cell_state_quant"):
-            module = Model()
-        module = None
-    else:
-        module = Model()
+    module = Model()

     in_size = (FEATURES, 1, IN_CH)
     inp = torch.randn(in_size)
diff --git a/tests/brevitas/nn/test_nn_quantizers.py b/tests/brevitas/nn/test_nn_quantizers.py
index a10c7160e..a20575cba 100644
--- a/tests/brevitas/nn/test_nn_quantizers.py
+++ b/tests/brevitas/nn/test_nn_quantizers.py
@@ -86,8
@@ def test_quant_lstm_rnn_full(model_input, current_cases):
     is_input_quanttensor = kwargs['io_quant'] is not None or kwargs['input_quantized']
     return_quant_tensor = kwargs['return_quant_tensor']
-    if return_quant_tensor and (kwargs['io_quant'] is None or kwargs['signed_act'] is None):
-        pytest.skip("Invalid config")
     model, input = model_input
     if (kwargs['bias_quant'] == 'quant_external') and ( \
         (not is_input_quanttensor or kwargs['weight_quant'] is None) or \

From 05677a4143f304446f0df80cdd41dc10c82f8bb3 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Mon, 5 Feb 2024 06:43:13 +0000
Subject: [PATCH 19/32] Cleanup

---
 src/brevitas/export/manager.py             |  5 +-
 src/brevitas/nn/hadamard_classifier.py     | 12 ++-
 src/brevitas/nn/quant_avg_pool.py          |  4 +-
 src/brevitas/nn/quant_upsample.py          |  4 +-
 src/brevitas/quant_tensor/__init__.py      | 85 ++++++++--------------
 src/brevitas/quant_tensor/torch_handler.py |  5 +-
 tests/brevitas/nn/test_nn_quantizers.py    | 10 +--
 7 files changed, 46 insertions(+), 79 deletions(-)

diff --git a/src/brevitas/export/manager.py b/src/brevitas/export/manager.py
index f8f1189fd..bc62920f3 100644
--- a/src/brevitas/export/manager.py
+++ b/src/brevitas/export/manager.py
@@ -207,10 +207,9 @@ def _cache_fn_dispatcher(cls, fn, input, *args, **kwargs):
         if isinstance(input, QuantTensor):
             inp_cache = None
             out_cache = None
-            if input.is_not_none:
-                inp_cache = _CachedIO(input, metadata_only=True)
+            inp_cache = _CachedIO(input, metadata_only=True)
             output = fn(input, *args, **kwargs)
-            if isinstance(output, QuantTensor) and output.is_not_none:
+            if isinstance(output, QuantTensor):
                 out_cache = _CachedIO(output, metadata_only=True)
             cached_io = (inp_cache, out_cache)
             if fn in cls._cached_io_handler_map:
diff --git a/src/brevitas/nn/hadamard_classifier.py b/src/brevitas/nn/hadamard_classifier.py
index d3f22f679..e78163321 100644
--- a/src/brevitas/nn/hadamard_classifier.py
+++ b/src/brevitas/nn/hadamard_classifier.py
@@ -49,15 +49,13 @@ def forward(self, inp):
         out = inp.value / norm
         out = nn.functional.linear(out, self.proj[:self.out_channels, :self.in_channels])
         out = -self.scale * out
-        if inp.scale is not None:
+        if isinstance(inp, QuantTensor):
             output_scale = inp.scale * self.scale / norm
-        if inp.bit_width is not None:
             output_bit_width = self.max_output_bit_width(inp.bit_width)
-        if (self.return_quant_tensor and inp.zero_point is not None and
-                (inp.zero_point != 0.0).any()):
-            raise RuntimeError("Computing zero point of output accumulator not supported yet.")
-        else:
-            output_zp = inp.zero_point
+        if self.return_quant_tensor and (inp.zero_point != 0.0).any():
+            raise RuntimeError("Computing zero point of output accumulator not supported yet.")
+        else:
+            output_zp = inp.zero_point
         out = QuantTensor(
             value=out,
             scale=output_scale,
diff --git a/src/brevitas/nn/quant_avg_pool.py b/src/brevitas/nn/quant_avg_pool.py
index d8c83e3f2..a427a4d25 100644
--- a/src/brevitas/nn/quant_avg_pool.py
+++ b/src/brevitas/nn/quant_avg_pool.py
@@ -59,7 +59,7 @@ def forward(self, input: Union[Tensor, QuantTensor]):
             return self.export_handler(x.value)
         x = x.set(value=super(TruncAvgPool2d, self).forward(x.value))
         if self.is_trunc_quant_enabled:
-            assert x.is_not_none  # check input quant tensor is filled with values
+            assert isinstance(x, QuantTensor)  # check input quant tensor is filled with values
             # remove avg scaling
             rescaled_value = x.value * self._avg_scaling
             x = x.set(value=rescaled_value)
@@ -138,7 +138,7 @@ def forward(self, input: Union[Tensor, QuantTensor]):
             self._cached_kernel_size = k_size
self._cached_kernel_stride = stride if self.is_trunc_quant_enabled: - assert y.is_not_none # check input quant tensor is filled with values + assert isinstance(y, QuantTensor) # check input quant tensor is filled with values reduce_size = reduce(mul, k_size, 1) rescaled_value = y.value * reduce_size # remove avg scaling y = y.set(value=rescaled_value) diff --git a/src/brevitas/nn/quant_upsample.py b/src/brevitas/nn/quant_upsample.py index 10727cec5..f2735abf5 100644 --- a/src/brevitas/nn/quant_upsample.py +++ b/src/brevitas/nn/quant_upsample.py @@ -45,7 +45,7 @@ def forward(self, input: Union[Tensor, QuantTensor]): y_value = interpolate(x.value, self.size, self.scale_factor, self.mode, self.align_corners) if self.mode != 'nearest': # round interpolated values to scale - assert x.scale is not None, 'Input scale factor required to interpolate correctly' + assert isinstance(x, QuantTensor), 'Input scale factor required to interpolate correctly' y_value = round_ste(y_value / x.scale) * x.scale y = x.set(value=y_value) return self.pack_output(y) @@ -73,7 +73,7 @@ def forward(self, input: Union[Tensor, QuantTensor]): return out y_value = interpolate(x.value, self.size, self.scale_factor, self.mode, self.align_corners) # round interpolated values to scale - assert x.scale is not None, 'Input scale factor required to interpolate correctly' + assert isinstance(x, QuantTensor), 'Input scale factor required to interpolate correctly' y_value = round_ste(y_value / x.scale) * x.scale y = x.set(value=y_value) return self.pack_output(y) diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index 08bc5e2dc..0c2b021d5 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -42,17 +42,6 @@ def _unpack_quant_tensor(input_data): return input_data -def _is_all_nested_not_none(input_data): - if isinstance(input_data, QuantTensor): - return input_data.is_not_none - elif isinstance(input_data, (tuple, list)): - return all([_is_all_nested_not_none(v) for v in input_data]) - elif isinstance(input_data, dict): - return all([_is_all_nested_not_none(v) for v in input_data.values()]) - else: - return True - - class QuantTensor(QuantTensorBase): def __new__(cls, value, scale, zero_point, bit_width, signed, training): @@ -88,8 +77,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} if (func not in QUANT_TENSOR_FN_HANDLER or - not all(issubclass(t, QuantTensor) for t in types) or - not (_is_all_nested_not_none(args) and _is_all_nested_not_none(kwargs))): + not all(issubclass(t, QuantTensor) for t in types)): args = _unpack_quant_tensor(args) kwargs = _unpack_quant_tensor(kwargs) return func(*args, **kwargs) @@ -99,12 +87,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None): def tensor(self): return self.value - @property - def is_not_none(self): - return ( - self.value is not None and self.scale is not None and self.zero_point is not None and - self.bit_width is not None and self.signed is not None) - @property def _pre_round_int_value(self): value = self.value @@ -120,30 +102,27 @@ def _pre_round_int_value(self): @property def is_valid(self): - if self.is_not_none: - with torch.no_grad(): - pre_round_int_value = self._pre_round_int_value - rounded_int_value = torch.round(pre_round_int_value) - max_abs_diff = torch.max(torch.abs(pre_round_int_value - rounded_int_value)) - atol = BFLOAT16_IS_VALID_ATOL if self.value.dtype == torch.bfloat16 else IS_VALID_ATOL - is_int = max_abs_diff 
< atol - if self.bit_width >= 2: - if self.signed: - is_upper_b = (2.0 ** (self.bit_width - 1) - 1 >= rounded_int_value).all() - is_lower_b = (-2.0 ** (self.bit_width - 1) <= rounded_int_value).all() - else: - is_upper_b = (2.0 ** self.bit_width - 1 >= rounded_int_value).all() - is_lower_b = (0. <= rounded_int_value).all() - return (is_int & is_upper_b & is_lower_b).item() - else: # binary case - unique_vals = rounded_int_value.unique( - sorted=False, return_counts=False, return_inverse=False) - is_binary = unique_vals.view(-1).size()[0] == 2 - is_signed = (unique_vals < 0.).any().item() - sign_match = is_signed == self.signed - return is_int.item() and is_binary and sign_match - else: - return False + with torch.no_grad(): + pre_round_int_value = self._pre_round_int_value + rounded_int_value = torch.round(pre_round_int_value) + max_abs_diff = torch.max(torch.abs(pre_round_int_value - rounded_int_value)) + atol = BFLOAT16_IS_VALID_ATOL if self.value.dtype == torch.bfloat16 else IS_VALID_ATOL + is_int = max_abs_diff < atol + if self.bit_width >= 2: + if self.signed: + is_upper_b = (2.0 ** (self.bit_width - 1) - 1 >= rounded_int_value).all() + is_lower_b = (-2.0 ** (self.bit_width - 1) <= rounded_int_value).all() + else: + is_upper_b = (2.0 ** self.bit_width - 1 >= rounded_int_value).all() + is_lower_b = (0. <= rounded_int_value).all() + return (is_int & is_upper_b & is_lower_b).item() + else: # binary case + unique_vals = rounded_int_value.unique( + sorted=False, return_counts=False, return_inverse=False) + is_binary = unique_vals.view(-1).size()[0] == 2 + is_signed = (unique_vals < 0.).any().item() + sign_match = is_signed == self.signed + return is_int.item() and is_binary and sign_match @property def device(self): @@ -168,18 +147,18 @@ def detach_(self): def detach(self): return QuantTensor( self.value.detach(), - self.scale.detach() if self.scale is not None else None, - self.zero_point.detach() if self.zero_point is not None else None, - self.bit_width.detach() if self.bit_width is not None else None, + self.scale.detach(), + self.zero_point.detach(), + self.bit_width.detach(), self.signed, self.training) def contiguous(self): return QuantTensor( self.value.contiguous(), - self.scale.contiguous() if self.scale is not None else None, - self.zero_point.contiguous() if self.zero_point is not None else None, - self.bit_width.contiguous() if self.bit_width is not None else None, + self.scale.contiguous(), + self.zero_point.contiguous(), + self.bit_width.contiguous(), self.signed, self.training) @@ -284,7 +263,7 @@ def cat(tensors, dim, out=None): return tensors[0] else: first_qt = tensors[0] - if all([isinstance(qt, QuantTensor) and qt.is_not_none for qt in tensors]): + if all([isinstance(qt, QuantTensor) for qt in tensors]): for qt in tensors[1:]: first_qt.check_scaling_factors_same(qt) first_qt.check_zero_points_same(qt) @@ -364,7 +343,7 @@ def cpu(self, *args, **kwargs): self.training) def __add__(self, other): - if isinstance(other, QuantTensor) and self.is_not_none and other.is_not_none: + if isinstance(other, QuantTensor): self.check_scaling_factors_same(other) output_value = self.value + other.value output_scale = (self.scale + other.scale) / 2 @@ -396,7 +375,7 @@ def __rmul__(self, other): return self.__mul__(other) def __mul__(self, other): - if isinstance(other, QuantTensor) and self.is_not_none and other.is_not_none: + if isinstance(other, QuantTensor): output_value = self.value * other.value output_scale = self.scale * other.scale output_bit_width = self.bit_width + 
other.bit_width @@ -426,7 +405,7 @@ def __str__(self): return f"QuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})" def __truediv__(self, other): - if isinstance(other, QuantTensor) and self.is_not_none and other.is_not_none: + if isinstance(other, QuantTensor): output_tensor = self.value / other.value # Note, output tensor not guaranteed to pass self.is_valid() max_int_denominator = 2 ** (other.bit_width - int(other.signed)) output_scale = self.scale / (other.scale * max_int_denominator) diff --git a/src/brevitas/quant_tensor/torch_handler.py b/src/brevitas/quant_tensor/torch_handler.py index 3b64bca89..1b6e43a37 100644 --- a/src/brevitas/quant_tensor/torch_handler.py +++ b/src/brevitas/quant_tensor/torch_handler.py @@ -23,10 +23,7 @@ def decorator(func): def quant_invariant_handler(fn, inp, *args, **kwargs): out_value = fn(inp.value, *args, **kwargs) - if inp.is_not_none: - return inp.set(value=out_value) - else: - return out_value + return inp.set(value=out_value) @implements(torch.flatten) diff --git a/tests/brevitas/nn/test_nn_quantizers.py b/tests/brevitas/nn/test_nn_quantizers.py index a20575cba..db025364e 100644 --- a/tests/brevitas/nn/test_nn_quantizers.py +++ b/tests/brevitas/nn/test_nn_quantizers.py @@ -62,14 +62,8 @@ def test_quant_wbiol(model_input, current_cases): if kwargs['return_quant_tensor']: assert isinstance(output, QuantTensor) - # Empty QuantTensor - if ( not kwargs['input_quantized'] or kwargs['weight_quant'] is None) and \ - kwargs['io_quant'] is None: - assert output.scale is None - assert output.bit_width is None - else: # "Full" QuantTensor - assert output.scale is not None - assert output.bit_width is not None + assert output.scale is not None + assert output.bit_width is not None else: assert isinstance(output, torch.Tensor) From 1a7ee10fe8484c70380ccbbfc7362fe3f48edc9f Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 5 Feb 2024 08:41:41 +0000 Subject: [PATCH 20/32] Typing --- src/brevitas/nn/mixin/base.py | 4 ++-- src/brevitas/proxy/parameter_quant.py | 15 +++++++-------- src/brevitas/proxy/runtime_quant.py | 6 +++--- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index e4743207a..cb0cd810a 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -155,7 +155,7 @@ def quant_output_bit_width(self): else: return None - def unpack_input(self, inp: Union[Tensor, QuantTensor]): + def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: self._set_global_is_quant_layer(True) # Hack to recognize a QuantTensor that has decayed to a tuple # when used as input to tracing (e.g. 
during ONNX export) @@ -174,7 +174,7 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]): inp = inp.rename(None) return inp - def pack_output(self, quant_output: QuantTensor): + def pack_output(self, quant_output: QuantTensor) -> Union[Tensor, QuantTensor]: if not self.training and self.cache_inference_quant_out and isinstance(quant_output, QuantTensor): self._cached_out = _CachedIO(quant_output.detach(), self.cache_quant_io_metadata_only) diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index f7f120697..1f6adf549 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -3,7 +3,7 @@ from abc import ABCMeta from abc import abstractmethod -from typing import List, Optional, Tuple +from typing import Optional, Union import torch from torch import Tensor @@ -94,7 +94,7 @@ def bit_width(self): bit_width_ = self.__call__(self.tracked_parameter_list[0]).bit_width return bit_width_ - def forward(self, x: torch.Tensor) -> QuantTensor: + def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: if self.is_quant_enabled: impl = self.export_handler if self.export_mode else self.tensor_quant out, scale, zero_point, bit_width = impl(x) @@ -115,13 +115,13 @@ def pre_zero_point(self): out, scale, zero_point, bit_width, pre_scale, pre_zero_point = output_tuple return pre_zero_point - def forward(self, x: torch.Tensor) -> QuantTensor: + def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: if self.is_quant_enabled: impl = self.export_handler if self.export_mode else self.tensor_quant out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x) return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) else: # quantization disabled - return QuantTensor(x, training=self.training) + return x class DecoupledWeightQuantWithInputProxyFromInjector(DecoupledWeightQuantProxyFromInjector): @@ -145,9 +145,8 @@ def pre_scale(self): def pre_zero_point(self): raise NotImplementedError - def forward( - self, x: torch.Tensor, input_bit_width: torch.Tensor, - input_is_signed: bool) -> QuantTensor: + def forward(self, x: torch.Tensor, input_bit_width: torch.Tensor, + input_is_signed: bool) -> Union[Tensor, QuantTensor]: if self.is_quant_enabled: impl = self.export_handler if self.export_mode else self.tensor_quant out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x, input_bit_width, input_is_signed) @@ -199,7 +198,7 @@ def forward( self, x: Tensor, input_scale: Optional[Tensor] = None, - input_bit_width: Optional[Tensor] = None) -> QuantTensor: + input_bit_width: Optional[Tensor] = None) -> Union[Tensor, QuantTensor]: if self.is_quant_enabled: impl = self.export_handler if self.export_mode else self.tensor_quant if self.requires_input_scale and input_scale is None: diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 359dd76b3..0307dfd34 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -137,7 +137,7 @@ def bit_width(self): scale = self.__call__(self._zero_hw_sentinel()).bit_width return scale - def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor: + def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: if self.fused_activation_quant_proxy is not None: y = x if isinstance(y, QuantTensor): @@ -188,7 +188,7 @@ def bit_width(self): class ClampQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol): - def forward(self, x: QuantTensor): 
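As a consequence of these Union[Tensor, QuantTensor] signatures, callers can no longer assume a proxy or layer always hands back a QuantTensor. A minimal usage sketch of the guarded pattern this implies (the QuantLinear choice and sizes are illustrative, not taken from this patch):

    import torch
    from brevitas.nn import QuantLinear
    from brevitas.quant_tensor import QuantTensor

    layer = QuantLinear(16, 8, bias=True, return_quant_tensor=False)
    out = layer(torch.randn(2, 16))
    # With quantization disabled or return_quant_tensor=False, the output
    # decays to a plain Tensor, so metadata access must be guarded:
    if isinstance(out, QuantTensor):
        scale, bit_width = out.scale, out.bit_width
    else:
        scale, bit_width = None, None  # no quantization metadata available
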
+ def forward(self, x: QuantTensor) -> Union[Tensor, QuantTensor]: if self.is_quant_enabled: out_tuple = self.tensor_quant(x.value, x.scale, x.bit_width) out_value, out_scale, out_zp, out_bit_width = out_tuple @@ -206,7 +206,7 @@ def bit_width(self): bit_width = self.__call__(empty_imp).bit_width return bit_width - def forward(self, x: QuantTensor): + def forward(self, x: QuantTensor) -> Union[Tensor, QuantTensor]: if self.is_quant_enabled: if self.export_mode: out_tuple = self.export_handler( From 81f9a828432c588bcdb6551741e66e7d08568fde Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 5 Feb 2024 08:43:26 +0000 Subject: [PATCH 21/32] Removing comments --- src/brevitas/nn/quant_avg_pool.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/brevitas/nn/quant_avg_pool.py b/src/brevitas/nn/quant_avg_pool.py index a427a4d25..79d493c43 100644 --- a/src/brevitas/nn/quant_avg_pool.py +++ b/src/brevitas/nn/quant_avg_pool.py @@ -59,7 +59,7 @@ def forward(self, input: Union[Tensor, QuantTensor]): return self.export_handler(x.value) x = x.set(value=super(TruncAvgPool2d, self).forward(x.value)) if self.is_trunc_quant_enabled: - assert isinstance(x, QuantTensor) # check input quant tensor is filled with values + assert isinstance(x, QuantTensor) # remove avg scaling rescaled_value = x.value * self._avg_scaling x = x.set(value=rescaled_value) @@ -138,7 +138,7 @@ def forward(self, input: Union[Tensor, QuantTensor]): self._cached_kernel_size = k_size self._cached_kernel_stride = stride if self.is_trunc_quant_enabled: - assert isinstance(y, QuantTensor) # check input quant tensor is filled with values + assert isinstance(y, QuantTensor) reduce_size = reduce(mul, k_size, 1) rescaled_value = y.value * reduce_size # remove avg scaling y = y.set(value=rescaled_value) From 55eb431a55c51a0a1a8f004a2b04315fbb4d4120 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sun, 11 Feb 2024 15:39:25 +0000 Subject: [PATCH 22/32] Cleanup --- src/brevitas/nn/quant_layer.py | 12 +++++++----- src/brevitas/proxy/runtime_quant.py | 6 ++---- src/brevitas/quant_tensor/__init__.py | 23 ++++++++++------------- tests/brevitas/nn/test_nn_quantizers.py | 16 ---------------- 4 files changed, 19 insertions(+), 38 deletions(-) diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index 6fb8f69e0..8f80b5707 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -359,11 +359,13 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe _unpack_quant_tensor(quant_input), _unpack_quant_tensor(quant_weight), None) if self.return_quant_tensor and not self.is_output_quant_enabled: - if (isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor) and - ((quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any())): - raise RuntimeError("Computing zero point of output accumulator not supported yet.") - elif quant_input.zero_point is not None and output_zero_point is None: - output_zero_point = quant_input.zero_point + if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor): + if (quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any(): + raise RuntimeError( + "Computing zero point of output accumulator not supported yet.") + else: + output_zero_point = quant_input.zero_point + elif self.return_quant_tensor and output_zero_point is None: output_zero_point = torch.zeros(1).type_as(output_tensor) diff --git a/src/brevitas/proxy/runtime_quant.py 
b/src/brevitas/proxy/runtime_quant.py index 0307dfd34..2c4f7cf2f 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -167,10 +167,8 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: y = y[0] return y else: - if isinstance(x, QuantTensor): # passthrough - return x - else: - return x + # If fused activation quant proxy is not enabled, return the input + return x class DynamicActQuantProxyFromInjector(ActQuantProxyFromInjector): diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index 0c2b021d5..c66690c50 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -190,10 +190,7 @@ def check_input_type(tensor): @staticmethod def is_zero_zero_point(tensor): QuantTensor.check_input_type(tensor) - if tensor.zero_point is not None: - return (tensor.zero_point == 0.).all() - else: - return None + return (tensor.zero_point == 0.).all() def check_scaling_factors_same(self, other): if self.training is not None and self.training: @@ -318,27 +315,27 @@ def __neg__(self): def to(self, *args, **kwargs): return QuantTensor( self.value.to(*args, **kwargs), - self.scale.to(*args, **kwargs) if self.scale is not None else None, - self.zero_point.to(*args, **kwargs) if self.zero_point is not None else None, - self.bit_width.to(*args, **kwargs) if self.bit_width is not None else None, + self.scale.to(*args, **kwargs), + self.zero_point.to(*args, **kwargs), + self.bit_width.to(*args, **kwargs), self.signed, self.training) def cuda(self, *args, **kwargs): return QuantTensor( self.value.cuda(*args, **kwargs), - self.scale.cuda(*args, **kwargs) if self.scale is not None else None, - self.zero_point.cuda(*args, **kwargs) if self.zero_point is not None else None, - self.bit_width.cuda(*args, **kwargs) if self.bit_width is not None else None, + self.scale.cuda(*args, **kwargs), + self.zero_point.cuda(*args, **kwargs), + self.bit_width.cuda(*args, **kwargs), self.signed, self.training) def cpu(self, *args, **kwargs): return QuantTensor( self.value.cpu(*args, **kwargs), - self.scale.cpu(*args, **kwargs) if self.scale is not None else None, - self.zero_point.cpu(*args, **kwargs) if self.zero_point is not None else None, - self.bit_width.cpu(*args, **kwargs) if self.bit_width is not None else None, + self.scale.cpu(*args, **kwargs), + self.zero_point.cpu(*args, **kwargs), + self.bit_width.cpu(*args, **kwargs), self.signed, self.training) diff --git a/tests/brevitas/nn/test_nn_quantizers.py b/tests/brevitas/nn/test_nn_quantizers.py index db025364e..b0db249af 100644 --- a/tests/brevitas/nn/test_nn_quantizers.py +++ b/tests/brevitas/nn/test_nn_quantizers.py @@ -62,8 +62,6 @@ def test_quant_wbiol(model_input, current_cases): if kwargs['return_quant_tensor']: assert isinstance(output, QuantTensor) - assert output.scale is not None - assert output.bit_width is not None else: assert isinstance(output, torch.Tensor) @@ -101,24 +99,18 @@ def test_quant_lstm_rnn_full(model_input, current_cases): if return_quant_tensor: assert isinstance(output, QuantTensor) - assert output.scale is not None - assert output.bit_width is not None else: assert isinstance(output, torch.Tensor) if h is not None: if return_quant_tensor: assert isinstance(h, QuantTensor) - assert h.scale is not None - assert h.bit_width is not None else: assert isinstance(h, torch.Tensor) if c is not None: if return_quant_tensor: assert isinstance(c, QuantTensor) - assert c.scale is not None - assert c.bit_width is not 
None else: assert isinstance(c, torch.Tensor) @@ -155,24 +147,18 @@ def test_quant_lstm_rnn(model_input, current_cases): if return_quant_tensor: assert isinstance(output, QuantTensor) - assert output.scale is not None - assert output.bit_width is not None else: assert isinstance(output, torch.Tensor) if h is not None: if return_quant_tensor: assert isinstance(h, QuantTensor) - assert h.scale is not None - assert h.bit_width is not None else: assert isinstance(h, torch.Tensor) if c is not None: if return_quant_tensor: assert isinstance(c, QuantTensor) - assert c.scale is not None - assert c.bit_width is not None else: assert isinstance(c, torch.Tensor) @@ -204,7 +190,5 @@ def test_quant_mha(model_input, current_cases): if kwargs['return_quant_tensor']: assert isinstance(output, QuantTensor) - assert output.scale is not None - assert output.bit_width is not None else: assert isinstance(output, torch.Tensor) From 16e19350d54c92b1db8d45c269099e9dffc85c0f Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Feb 2024 11:24:16 +0000 Subject: [PATCH 23/32] Fix --- src/brevitas/nn/quant_layer.py | 41 ++++++++++++++++------------------ 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index 8f80b5707..4adef2cea 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -317,9 +317,8 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe self.is_output_quant_enabled) and self.return_quant_tensor: raise RuntimeError("QuantLayer is not correctly configured") - if (self.return_quant_tensor or - (self.is_bias_quant_enabled and - (self.bias_quant.requires_input_scale or self.bias_quant.requires_input_bit_width))): + if (self.is_bias_quant_enabled and + (self.bias_quant.requires_input_scale or self.bias_quant.requires_input_bit_width)): if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor): output_bit_width = self.max_acc_bit_width( quant_input.bit_width, quant_weight.bit_width) @@ -338,35 +337,33 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe _unpack_quant_tensor(quant_weight), _unpack_quant_tensor(quant_bias)) - if (self.return_quant_tensor): - if output_scale is not None: - if (isinstance(quant_bias, QuantTensor) and quant_bias.scale.data_ptr() != - output_scale.data_ptr()) or not isinstance(quant_bias, QuantTensor): - channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1 - output_scale_broadcast_shape = compute_channel_view_shape( - inp, channel_dim=channel_dim) - output_zero_point = -_unpack_quant_tensor(quant_bias).view( - output_scale_broadcast_shape) / output_scale - - if output_bit_width is not None and isinstance(quant_bias, QuantTensor): - output_bit_width = torch.where( - quant_bias.bit_width > output_bit_width, - quant_bias.bit_width, - output_bit_width) - output_bit_width = output_bit_width + 1 + if output_scale is not None: + if (isinstance(quant_bias, QuantTensor) and + quant_bias.scale.data_ptr() != output_scale.data_ptr()) or not isinstance( + quant_bias, QuantTensor): + channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1 + output_scale_broadcast_shape = compute_channel_view_shape( + inp, channel_dim=channel_dim) + output_zero_point = -_unpack_quant_tensor(quant_bias).view( + output_scale_broadcast_shape) / output_scale + + if output_bit_width is not None and isinstance(quant_bias, QuantTensor): + output_bit_width = torch.where( + quant_bias.bit_width > output_bit_width, 
quant_bias.bit_width, output_bit_width) + output_bit_width = output_bit_width + 1 else: output_tensor = self.inner_forward_impl( _unpack_quant_tensor(quant_input), _unpack_quant_tensor(quant_weight), None) - if self.return_quant_tensor and not self.is_output_quant_enabled: + if not self.is_output_quant_enabled: if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor): if (quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any(): raise RuntimeError( "Computing zero point of output accumulator not supported yet.") - else: + elif output_zero_point is None: output_zero_point = quant_input.zero_point - elif self.return_quant_tensor and output_zero_point is None: + elif output_zero_point is None: output_zero_point = torch.zeros(1).type_as(output_tensor) if not self.return_quant_tensor or not compute_output_quant_tensor: From f9990d0f70f65ff608f10419e7e0c045a4140b33 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Feb 2024 12:26:01 +0000 Subject: [PATCH 24/32] Fix --- src/brevitas/nn/quant_layer.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index 4adef2cea..043ac9072 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -317,15 +317,11 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe self.is_output_quant_enabled) and self.return_quant_tensor: raise RuntimeError("QuantLayer is not correctly configured") - if (self.is_bias_quant_enabled and - (self.bias_quant.requires_input_scale or self.bias_quant.requires_input_bit_width)): - if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor): - output_bit_width = self.max_acc_bit_width( - quant_input.bit_width, quant_weight.bit_width) + if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor): + output_bit_width = self.max_acc_bit_width(quant_input.bit_width, quant_weight.bit_width) - output_scale = self.quant_output_scale_impl( - inp, quant_input.scale, quant_weight.scale) - output_signed = quant_input.signed or quant_weight.signed + output_scale = self.quant_output_scale_impl(inp, quant_input.scale, quant_weight.scale) + output_signed = quant_input.signed or quant_weight.signed if self.bias is not None: quant_bias = self.bias_quant(self.bias, output_scale, output_bit_width) @@ -357,7 +353,8 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe if not self.is_output_quant_enabled: if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor): - if (quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any(): + if (quant_input.zero_point != 0.0 + ).any() or (quant_weight.zero_point != 0.0).any() and self.return_quant_tensor: raise RuntimeError( "Computing zero point of output accumulator not supported yet.") elif output_zero_point is None: @@ -366,9 +363,7 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe elif output_zero_point is None: output_zero_point = torch.zeros(1).type_as(output_tensor) - if not self.return_quant_tensor or not compute_output_quant_tensor: - quant_output = output_tensor - else: + if compute_output_quant_tensor: quant_output = QuantTensor( output_tensor, scale=output_scale, @@ -376,5 +371,8 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe bit_width=output_bit_width, signed=output_signed, training=self.training) + else: 
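For intuition, the output metadata assembled in this hunk follows standard fixed-point algebra: scales multiply under matmul, and an unquantized bias shows up as a zero-point shift. A toy numeric sketch, with values and names that are illustrative only:

    import torch

    input_scale, weight_scale = torch.tensor(0.1), torch.tensor(0.05)
    output_scale = input_scale * weight_scale  # accumulator scale after matmul
    # A float bias folded into the accumulator acts as a zero-point shift,
    # which is where output_zero_point = -bias / output_scale comes from:
    bias = torch.tensor(0.02)
    output_zero_point = -bias / output_scale
    acc = torch.tensor(12.3)  # accumulator value in float form
    int_view = acc / output_scale + output_zero_point  # integer-domain view
    recovered = (int_view - output_zero_point) * output_scale
    assert torch.isclose(recovered, acc)
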
+ quant_output = output_tensor + quant_output = self.output_quant(quant_output) return self.pack_output(quant_output) From 2dcbdcaa36084da02d8264b8e29dc12b5ba6a6c2 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Feb 2024 13:16:42 +0000 Subject: [PATCH 25/32] Fix for avgpool --- src/brevitas/nn/mixin/base.py | 3 ++- src/brevitas/nn/quant_avg_pool.py | 43 +++++++++++++++++-------------- src/brevitas/nn/quant_layer.py | 1 - 3 files changed, 26 insertions(+), 21 deletions(-) diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index cb0cd810a..ba072cab1 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -174,12 +174,13 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe inp = inp.rename(None) return inp - def pack_output(self, quant_output: QuantTensor) -> Union[Tensor, QuantTensor]: + def pack_output(self, quant_output: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: if not self.training and self.cache_inference_quant_out and isinstance(quant_output, QuantTensor): self._cached_out = _CachedIO(quant_output.detach(), self.cache_quant_io_metadata_only) self._set_global_is_quant_layer(False) if self.return_quant_tensor: + assert isinstance(quant_output, QuantTensor) return quant_output else: return _unpack_quant_tensor(quant_output) diff --git a/src/brevitas/nn/quant_avg_pool.py b/src/brevitas/nn/quant_avg_pool.py index 79d493c43..0e167bed7 100644 --- a/src/brevitas/nn/quant_avg_pool.py +++ b/src/brevitas/nn/quant_avg_pool.py @@ -13,6 +13,7 @@ from brevitas.function.ops import max_int from brevitas.function.ops_ste import ceil_ste from brevitas.inject.defaults import RoundTo8bit +from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import QuantTensor from .mixin.acc import AccQuantType @@ -56,15 +57,17 @@ def _avg_scaling(self): def forward(self, input: Union[Tensor, QuantTensor]): x = self.unpack_input(input) if self.export_mode: - return self.export_handler(x.value) - x = x.set(value=super(TruncAvgPool2d, self).forward(x.value)) - if self.is_trunc_quant_enabled: - assert isinstance(x, QuantTensor) - # remove avg scaling - rescaled_value = x.value * self._avg_scaling - x = x.set(value=rescaled_value) - x = x.set(bit_width=self.max_acc_bit_width(x.bit_width)) - x = self.trunc_quant(x) + return self.export_handler(_unpack_quant_tensor(x)) + if isinstance(x, QuantTensor): + x = x.set(value=super(TruncAvgPool2d, self).forward(x.value)) + if self.is_trunc_quant_enabled: + # remove avg scaling + rescaled_value = x.value * self._avg_scaling + x = x.set(value=rescaled_value) + x = x.set(bit_width=self.max_acc_bit_width(x.bit_width)) + x = self.trunc_quant(x) + else: + x = super(TruncAvgPool2d, self).forward(x) return self.pack_output(x) def max_acc_bit_width(self, input_bit_width): @@ -129,21 +132,23 @@ def forward(self, input: Union[Tensor, QuantTensor]): x = self.unpack_input(input) # shortcut execution through the export impl during export if self.export_mode: - out = self.export_handler(x.value) + out = self.export_handler(_unpack_quant_tensor(x)) self._set_global_is_quant_layer(False) return out - y = x.set(value=super(TruncAdaptiveAvgPool2d, self).forward(x.value)) - k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:]) if self.cache_kernel_size_stride: self._cached_kernel_size = k_size self._cached_kernel_stride = stride - if self.is_trunc_quant_enabled: - assert isinstance(y, QuantTensor) - reduce_size = reduce(mul, k_size, 1) 
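The rescaling handled in this hunk undoes the divide-by-kernel-size inside average pooling, so the truncation quantizer operates on an integer-valued sum rather than a fractional mean. Roughly, with illustrative values:

    import torch
    import torch.nn.functional as F

    x = torch.tensor([[[[1., 2.], [3., 4.]]]])  # already-quantized integer values
    avg = F.avg_pool2d(x, kernel_size=2)        # mean = 2.5, off the integer grid
    acc = avg * 4                               # * (kernel_h * kernel_w) restores the sum
    assert torch.equal(acc, torch.tensor([[[[10.]]]]))
    # acc is integer-valued again, so truncation can drop LSBs without
    # leaving the integer grid; the accumulator bit width grows by ceil(log2(4)).
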
- rescaled_value = y.value * reduce_size # remove avg scaling - y = y.set(value=rescaled_value) - y = y.set(bit_width=self.max_acc_bit_width(y.bit_width, reduce_size)) - y = self.trunc_quant(y) + if isinstance(x, QuantTensor): + y = x.set(value=super(TruncAdaptiveAvgPool2d, self).forward(x.value)) + k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:]) + if self.is_trunc_quant_enabled: + reduce_size = reduce(mul, k_size, 1) + rescaled_value = y.value * reduce_size # remove avg scaling + y = y.set(value=rescaled_value) + y = y.set(bit_width=self.max_acc_bit_width(y.bit_width, reduce_size)) + y = self.trunc_quant(y) + else: + y = super(TruncAdaptiveAvgPool2d, self).forward(x) return self.pack_output(y) def max_acc_bit_width(self, input_bit_width, reduce_size): diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index 043ac9072..ba4a474e2 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -319,7 +319,6 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor): output_bit_width = self.max_acc_bit_width(quant_input.bit_width, quant_weight.bit_width) - output_scale = self.quant_output_scale_impl(inp, quant_input.scale, quant_weight.scale) output_signed = quant_input.signed or quant_weight.signed From d41f96ed36ba67d944674eeb4a43f57eac8d983a Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Feb 2024 13:22:53 +0000 Subject: [PATCH 26/32] Cleanup --- src/brevitas/nn/quant_avg_pool.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/brevitas/nn/quant_avg_pool.py b/src/brevitas/nn/quant_avg_pool.py index 0e167bed7..5d567d0ca 100644 --- a/src/brevitas/nn/quant_avg_pool.py +++ b/src/brevitas/nn/quant_avg_pool.py @@ -56,8 +56,10 @@ def _avg_scaling(self): def forward(self, input: Union[Tensor, QuantTensor]): x = self.unpack_input(input) + if self.export_mode: return self.export_handler(_unpack_quant_tensor(x)) + if isinstance(x, QuantTensor): x = x.set(value=super(TruncAvgPool2d, self).forward(x.value)) if self.is_trunc_quant_enabled: @@ -67,7 +69,9 @@ def forward(self, input: Union[Tensor, QuantTensor]): x = x.set(bit_width=self.max_acc_bit_width(x.bit_width)) x = self.trunc_quant(x) else: + assert not self.is_trunc_quant_enabled x = super(TruncAvgPool2d, self).forward(x) + return self.pack_output(x) def max_acc_bit_width(self, input_bit_width): @@ -130,14 +134,17 @@ def compute_kernel_size_stride(self, input_shape, output_shape): def forward(self, input: Union[Tensor, QuantTensor]): x = self.unpack_input(input) + # shortcut execution through the export impl during export if self.export_mode: out = self.export_handler(_unpack_quant_tensor(x)) self._set_global_is_quant_layer(False) return out + if self.cache_kernel_size_stride: self._cached_kernel_size = k_size self._cached_kernel_stride = stride + if isinstance(x, QuantTensor): y = x.set(value=super(TruncAdaptiveAvgPool2d, self).forward(x.value)) k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:]) @@ -148,7 +155,9 @@ def forward(self, input: Union[Tensor, QuantTensor]): y = y.set(bit_width=self.max_acc_bit_width(y.bit_width, reduce_size)) y = self.trunc_quant(y) else: + assert not self.is_trunc_quant_enabled y = super(TruncAdaptiveAvgPool2d, self).forward(x) + return self.pack_output(y) def max_acc_bit_width(self, input_bit_width, reduce_size): From 875bc927544ec4e82392a280931c71e7da40916e Mon Sep 
17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Feb 2024 13:35:00 +0000 Subject: [PATCH 27/32] Notebook update --- notebooks/02_quant_activation_overview.ipynb | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/notebooks/02_quant_activation_overview.ipynb b/notebooks/02_quant_activation_overview.ipynb index 388f43cea..3756052db 100644 --- a/notebooks/02_quant_activation_overview.ipynb +++ b/notebooks/02_quant_activation_overview.ipynb @@ -312,8 +312,12 @@ }, { "cell_type": "code", - "execution_count": 11, - "metadata": {}, + "execution_count": 10, + "metadata": { + "tags": [ + "raises-exception" + ] + }, "outputs": [ { "data": { From f4c41325bbd1064cbe61dc24fa94f629525d7272 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Feb 2024 13:41:24 +0000 Subject: [PATCH 28/32] Notebook update --- notebooks/02_quant_activation_overview.ipynb | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/notebooks/02_quant_activation_overview.ipynb b/notebooks/02_quant_activation_overview.ipynb index 3756052db..39a0cfc14 100644 --- a/notebooks/02_quant_activation_overview.ipynb +++ b/notebooks/02_quant_activation_overview.ipynb @@ -349,18 +349,6 @@ "sigmoid_out_tensor" ] }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "from brevitas.quant_tensor import QuantTensor\n", - "\n", - "\n", - "assert not isinstance(sigmoid_out_tensor, QuantTensor)" - ] - }, { "cell_type": "markdown", "metadata": {}, From ea96a62cd9f2afc044cdebdf73e132d269785784 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Feb 2024 17:06:37 +0000 Subject: [PATCH 29/32] Fix tests --- src/brevitas/nn/quant_layer.py | 5 ++--- tests/brevitas/nn/test_nn_quantizers.py | 8 +++++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index ba4a474e2..8ec4be3aa 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -350,10 +350,9 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe output_tensor = self.inner_forward_impl( _unpack_quant_tensor(quant_input), _unpack_quant_tensor(quant_weight), None) - if not self.is_output_quant_enabled: + if not self.is_output_quant_enabled and self.return_quant_tensor: if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor): - if (quant_input.zero_point != 0.0 - ).any() or (quant_weight.zero_point != 0.0).any() and self.return_quant_tensor: + if (quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any(): raise RuntimeError( "Computing zero point of output accumulator not supported yet.") elif output_zero_point is None: diff --git a/tests/brevitas/nn/test_nn_quantizers.py b/tests/brevitas/nn/test_nn_quantizers.py index b0db249af..bbee8daca 100644 --- a/tests/brevitas/nn/test_nn_quantizers.py +++ b/tests/brevitas/nn/test_nn_quantizers.py @@ -185,7 +185,13 @@ def test_quant_mha(model_input, current_cases): with pytest.raises(RuntimeError, match='Input scale required'): output, _ = model(inp, inp, inp) return - + elif kwargs['weight_quant'] is not None and kwargs['io_quant'] is None: + if kwargs['weight_quant'] == 'quant_asym' and kwargs['return_quant_tensor']: + with pytest.raises( + RuntimeError, + match='Computing zero point of output accumulator not supported yet.'): + output, _ = model(inp, inp, inp) + return output, _ = model(inp, inp, inp) if kwargs['return_quant_tensor']: From 2c063b390b2da00860862330b71f20cea4629807 Mon 
Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 23 Feb 2024 13:04:17 +0000 Subject: [PATCH 30/32] Cleanup --- src/brevitas/export/onnx/standard/qoperator/handler/base.py | 5 +---- src/brevitas/nn/mixin/base.py | 2 +- src/brevitas/nn/quant_layer.py | 2 +- src/brevitas/proxy/runtime_quant.py | 2 +- 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/brevitas/export/onnx/standard/qoperator/handler/base.py b/src/brevitas/export/onnx/standard/qoperator/handler/base.py index e614d2ed5..e684f32d0 100644 --- a/src/brevitas/export/onnx/standard/qoperator/handler/base.py +++ b/src/brevitas/export/onnx/standard/qoperator/handler/base.py @@ -104,10 +104,7 @@ def input_quant_symbolic_kwargs(cls, module): @classmethod def input_dequant_symbolic_kwargs(cls, module): - if module._cached_inp is not None: - return cls.dequant_symbolic_kwargs_from_cached_io(module._cached_inp) - else: - return None + return cls.dequant_symbolic_kwargs_from_cached_io(module._cached_inp) @classmethod def dequant_symbolic_kwargs_from_cached_io(cls, cached_io): diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index ba072cab1..e54ad1ecc 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -331,7 +331,7 @@ def pack_quant_state(self, quant_state, quant): quant_state[2], quant_state[3], quant.is_signed, - training=self.training) + self.training) else: quant_state = torch.unsqueeze(quant_state[0], dim=0) return quant_state diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index 8ec4be3aa..f56ddd160 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -351,7 +351,7 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe _unpack_quant_tensor(quant_input), _unpack_quant_tensor(quant_weight), None) if not self.is_output_quant_enabled and self.return_quant_tensor: - if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor): + if compute_output_quant_tensor: if (quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any(): raise RuntimeError( "Computing zero point of output accumulator not supported yet.") diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 2c4f7cf2f..0324465c1 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -151,7 +151,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: else: y = self.fused_activation_quant_proxy(y) # If y is an empty QuantTensor, we need to check if this is a passthrough proxy, - # otherwise return an empty QuantTensor + # otherwise return a simple Tensor if isinstance(y, tuple) and not any(map(lambda f: f is None, y)): return QuantTensor(*y, signed=self.is_signed, training=self.training) elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant From beef750e0d580df4857b333fd67ff121015c0da3 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 23 Feb 2024 13:14:39 +0000 Subject: [PATCH 31/32] Fix --- .../export/onnx/standard/qoperator/handler/base.py | 5 ++++- src/brevitas/export/onnx/standard/qoperator/manager.py | 7 ------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/brevitas/export/onnx/standard/qoperator/handler/base.py b/src/brevitas/export/onnx/standard/qoperator/handler/base.py index e684f32d0..e614d2ed5 100644 --- a/src/brevitas/export/onnx/standard/qoperator/handler/base.py +++ 
b/src/brevitas/export/onnx/standard/qoperator/handler/base.py @@ -104,7 +104,10 @@ def input_quant_symbolic_kwargs(cls, module): @classmethod def input_dequant_symbolic_kwargs(cls, module): - return cls.dequant_symbolic_kwargs_from_cached_io(module._cached_inp) + if module._cached_inp is not None: + return cls.dequant_symbolic_kwargs_from_cached_io(module._cached_inp) + else: + return None @classmethod def dequant_symbolic_kwargs_from_cached_io(cls, cached_io): diff --git a/src/brevitas/export/onnx/standard/qoperator/manager.py b/src/brevitas/export/onnx/standard/qoperator/manager.py index 12f16cba3..464d5941a 100644 --- a/src/brevitas/export/onnx/standard/qoperator/manager.py +++ b/src/brevitas/export/onnx/standard/qoperator/manager.py @@ -1,18 +1,11 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -from typing import Optional, Tuple, Union - -from packaging import version -from torch import Tensor from torch.nn import functional as F from torch.nn import Module -from brevitas import torch_version from brevitas.export.manager import _set_layer_export_handler from brevitas.export.manager import _set_layer_export_mode -from brevitas.export.onnx.manager import ONNXBaseManager -from brevitas.quant_tensor import QuantTensor from ..function import DequantizeLinearFn from ..function import IntClipFn From 58d2397f57622843850a21def8fd546b80004a19 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 23 Feb 2024 16:00:25 +0000 Subject: [PATCH 32/32] Cleanup before merge --- src/brevitas/graph/gpxq.py | 17 ++++++----------- .../text_to_speech/melgan/res_stack_brevitas.py | 7 ++++--- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index 149e8ec03..e9641a5a8 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -234,19 +234,14 @@ def process_input(self, inp): inp_training = self.layer.training # If using quantized activations, inp could be QuantTensor. In - # this case, we overwrite the metadata if it is specified. + # this case, we overwrite the metadata. 
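Condensed, the capture-then-unwrap pattern this hunk converges on looks like the following sketch (the function name is hypothetical; both imported helpers appear elsewhere in this series):

    from brevitas.quant_tensor import QuantTensor, _unpack_quant_tensor

    def capture_input(inp):
        meta = None
        if isinstance(inp, QuantTensor):
            # After this series, every field of a QuantTensor is populated,
            # so no per-field None checks are needed before reading metadata.
            meta = (inp.scale, inp.zero_point, inp.bit_width, inp.signed, inp.training)
        return _unpack_quant_tensor(inp), meta  # always forward a plain Tensor
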
if isinstance(inp, QuantTensor): if self.layer_requires_input_quant and (self.quant_input is None): - if inp.scale is not None: - inp_scale = inp.scale - if inp.zero_point is not None: - inp_zero_point = inp.zero_point - if inp.bit_width is not None: - inp_bit_width = inp.bit_width - if inp.signed is not None: - inp_signed = inp.signed - if inp.training is not None: - inp_training = inp.training + inp_scale = inp.scale + inp_zero_point = inp.zero_point + inp_bit_width = inp.bit_width + inp_signed = inp.signed + inp_training = inp.training inp = inp.value # if the layer requires an input quant and the quant input cache has diff --git a/src/brevitas_examples/text_to_speech/melgan/res_stack_brevitas.py b/src/brevitas_examples/text_to_speech/melgan/res_stack_brevitas.py index 07dcfc8f4..6356cfea3 100644 --- a/src/brevitas_examples/text_to_speech/melgan/res_stack_brevitas.py +++ b/src/brevitas_examples/text_to_speech/melgan/res_stack_brevitas.py @@ -30,6 +30,7 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.""" +from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import QuantTensor from .common import * @@ -68,13 +69,13 @@ def forward(self, x): for layer in self.layers: x = self.scale_norm(x) if isinstance(x, QuantTensor): - x_unp, _, _ = x + x_unp = _unpack_quant_tensor(x) else: x_unp = x x_layer = self.scale_norm(layer(x_unp)) if isinstance(x_layer, QuantTensor): - x_layer_unp, _, _ = x_layer + x_layer_unp = _unpack_quant_tensor(x_layer) else: x_layer_unp = x_layer @@ -84,7 +85,7 @@ def forward(self, x): x = x + x_layer if isinstance(x, QuantTensor): - x, _, _ = x + x = _unpack_quant_tensor(x) return x
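Taken together, the series establishes a simple contract: a QuantTensor is always fully populated, and anything without quantization metadata is a plain Tensor. A quick end-to-end sketch of that contract, assuming post-series behavior (QuantIdentity and the shapes are illustrative):

    import torch
    from brevitas.nn import QuantIdentity
    from brevitas.quant_tensor import QuantTensor

    qt = QuantIdentity(return_quant_tensor=True)(torch.randn(4))
    assert isinstance(qt, QuantTensor)
    # All fields are populated, so validity can be checked directly:
    assert qt.scale is not None and qt.bit_width is not None
    assert qt.is_valid
    # With return_quant_tensor=False the output is a plain Tensor,
    # never a QuantTensor with None metadata:
    plain = QuantIdentity(return_quant_tensor=False)(torch.randn(4))
    assert isinstance(plain, torch.Tensor) and not isinstance(plain, QuantTensor)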