From 52d2b599e3cf71e3b23b20dfa73bb878c1419931 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 5 Feb 2024 05:27:34 +0000 Subject: [PATCH] Review --- src/brevitas/core/stats/stats_op.py | 4 ++-- src/brevitas/nn/mixin/base.py | 6 ++---- src/brevitas/nn/quant_layer.py | 10 ++++------ src/brevitas/nn/quant_max_pool.py | 4 ++-- src/brevitas/quant_tensor/__init__.py | 4 ---- tests/brevitas/nn/test_linear.py | 4 +++- tests/brevitas/nn/test_nn_quantizers.py | 6 +++--- 7 files changed, 16 insertions(+), 22 deletions(-) diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py index dee6011d5..fac729326 100644 --- a/src/brevitas/core/stats/stats_op.py +++ b/src/brevitas/core/stats/stats_op.py @@ -12,6 +12,7 @@ from brevitas import config from brevitas.core.utils import StatelessBuffer from brevitas.function.ops import max_int +from brevitas.quant_tensor import _unpack_quant_tensor # Use custom implementation of kthvalue as work around to (b)float16 kernel limitations from brevitas.utils.torch_utils import kthvalue @@ -478,8 +479,7 @@ def evaluate_loss(self, x, candidate): # Set to local_loss_mode before calling the proxy self.set_local_loss_mode(True) quant_value = self.proxy_forward(x) - if isinstance(quant_value, tuple): - quant_value = quant_value.value + quant_value = _unpack_quant_tensor(quant_value) loss = self.mse_loss_fn(x, quant_value) self.set_local_loss_mode(False) return loss diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index 7ef184de5..6ce8a10d2 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -18,6 +18,7 @@ from brevitas.inject import ExtendedInjector from brevitas.inject import Injector from brevitas.nn.utils import compute_channel_view_shape +from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import QuantTensor from .utils import filter_kwargs @@ -181,10 +182,7 @@ def pack_output(self, quant_output: QuantTensor): if self.return_quant_tensor: return quant_output else: - if isinstance(quant_output, QuantTensor): - return quant_output.value - else: - return quant_output + return _unpack_quant_tensor(quant_output) class QuantRecurrentLayerMixin(ExportMixin): diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index 206ab9584..42d74089d 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -136,8 +136,7 @@ def forward(self, input: Union[Tensor, QuantTensor]): quant_input = self.input_quant(input) # shortcut execution through the export impl during export if self.export_mode: - # quant_input_value = getattr(quant_input, 'value', quant_input) - out = self.export_handler(quant_input) + out = self.export_handler(_unpack_quant_tensor(quant_input)) self._set_global_is_quant_layer(False) return out out = self.act_quant(quant_input) @@ -298,15 +297,14 @@ def merge_bn_in(self, bn): def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: output_scale = None output_bit_width = None - output_signed = None output_zero_point = None + output_signed = None inp = self.unpack_input(inp) # shortcut execution through the export impl during export if self.export_mode: - inp_value = getattr(inp, 'value', inp) - out = self.export_handler(inp_value) + out = self.export_handler(_unpack_quant_tensor(inp)) self._set_global_is_quant_layer(False) return out @@ -369,7 +367,7 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe elif self.return_quant_tensor and output_zero_point is None: output_zero_point = torch.zeros(1).type_as(output_tensor) - if not self.return_quant_tensor or not compute_output_quant_tensor: + if not compute_output_quant_tensor: quant_output = output_tensor else: quant_output = QuantTensor( diff --git a/src/brevitas/nn/quant_max_pool.py b/src/brevitas/nn/quant_max_pool.py index 19cb2216c..b0ca2945e 100644 --- a/src/brevitas/nn/quant_max_pool.py +++ b/src/brevitas/nn/quant_max_pool.py @@ -7,6 +7,7 @@ from torch.nn import MaxPool1d from torch.nn import MaxPool2d +from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import QuantTensor from .mixin.base import QuantLayerMixin @@ -81,8 +82,7 @@ def requires_export_handler(self): def forward(self, input: Union[Tensor, QuantTensor]): x = self.unpack_input(input) if self.export_mode: - x_value = getattr(x, 'value', x) - out = self.export_handler(x_value) + out = self.export_handler(_unpack_quant_tensor(x)) self._set_global_is_quant_layer(False) return out x = x.set(value=super().forward(x.value)) diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index 865211b8f..08bc5e2dc 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -20,10 +20,6 @@ BFLOAT16_IS_VALID_ATOL = 0.5 -def _get_dequantize_tensor(input): - return input.value if isinstance(input, QuantTensor) else input - - class QuantTensorBase(NamedTuple): value: Tensor scale: Optional[Tensor] diff --git a/tests/brevitas/nn/test_linear.py b/tests/brevitas/nn/test_linear.py index 457b66e20..b9690ce63 100644 --- a/tests/brevitas/nn/test_linear.py +++ b/tests/brevitas/nn/test_linear.py @@ -56,5 +56,7 @@ def test_forward_bias_int(self): torch.rand(size=(3, INPUT_FEATURES)), torch.tensor(1.0), torch.tensor(0.0), - torch.tensor(3)) + torch.tensor(3), + signed=True, + training=False) assert mod(x) is not None diff --git a/tests/brevitas/nn/test_nn_quantizers.py b/tests/brevitas/nn/test_nn_quantizers.py index 2b74b0c10..a10c7160e 100644 --- a/tests/brevitas/nn/test_nn_quantizers.py +++ b/tests/brevitas/nn/test_nn_quantizers.py @@ -123,12 +123,12 @@ def test_quant_lstm_rnn_full(model_input, current_cases): assert isinstance(h, torch.Tensor) if c is not None: - if not return_quant_tensor: - assert isinstance(c, torch.Tensor) - else: + if return_quant_tensor: assert isinstance(c, QuantTensor) assert c.scale is not None assert c.bit_width is not None + else: + assert isinstance(c, torch.Tensor) @pytest_cases.parametrize_with_cases('model_input', cases=[case_quant_lstm, case_quant_rnn])