Skip to content

Commit

Permalink
Review
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Feb 5, 2024
1 parent 8222652 commit 52d2b59
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 22 deletions.
4 changes: 2 additions & 2 deletions src/brevitas/core/stats/stats_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from brevitas import config
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops import max_int
from brevitas.quant_tensor import _unpack_quant_tensor
# Use custom implementation of kthvalue as work around to (b)float16 kernel limitations
from brevitas.utils.torch_utils import kthvalue

Expand Down Expand Up @@ -478,8 +479,7 @@ def evaluate_loss(self, x, candidate):
# Set to local_loss_mode before calling the proxy
self.set_local_loss_mode(True)
quant_value = self.proxy_forward(x)
if isinstance(quant_value, tuple):
quant_value = quant_value.value
quant_value = _unpack_quant_tensor(quant_value)
loss = self.mse_loss_fn(x, quant_value)
self.set_local_loss_mode(False)
return loss
Expand Down
6 changes: 2 additions & 4 deletions src/brevitas/nn/mixin/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from brevitas.inject import ExtendedInjector
from brevitas.inject import Injector
from brevitas.nn.utils import compute_channel_view_shape
from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import QuantTensor

from .utils import filter_kwargs
Expand Down Expand Up @@ -181,10 +182,7 @@ def pack_output(self, quant_output: QuantTensor):
if self.return_quant_tensor:
return quant_output
else:
if isinstance(quant_output, QuantTensor):
return quant_output.value
else:
return quant_output
return _unpack_quant_tensor(quant_output)


class QuantRecurrentLayerMixin(ExportMixin):
Expand Down
10 changes: 4 additions & 6 deletions src/brevitas/nn/quant_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ def forward(self, input: Union[Tensor, QuantTensor]):
quant_input = self.input_quant(input)
# shortcut execution through the export impl during export
if self.export_mode:
# quant_input_value = getattr(quant_input, 'value', quant_input)
out = self.export_handler(quant_input)
out = self.export_handler(_unpack_quant_tensor(quant_input))
self._set_global_is_quant_layer(False)
return out
out = self.act_quant(quant_input)
Expand Down Expand Up @@ -298,15 +297,14 @@ def merge_bn_in(self, bn):
def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
output_scale = None
output_bit_width = None
output_signed = None
output_zero_point = None
output_signed = None

inp = self.unpack_input(inp)

# shortcut execution through the export impl during export
if self.export_mode:
inp_value = getattr(inp, 'value', inp)
out = self.export_handler(inp_value)
out = self.export_handler(_unpack_quant_tensor(inp))
self._set_global_is_quant_layer(False)
return out

Expand Down Expand Up @@ -369,7 +367,7 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
elif self.return_quant_tensor and output_zero_point is None:
output_zero_point = torch.zeros(1).type_as(output_tensor)

if not self.return_quant_tensor or not compute_output_quant_tensor:
if not compute_output_quant_tensor:
quant_output = output_tensor
else:
quant_output = QuantTensor(
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/nn/quant_max_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch.nn import MaxPool1d
from torch.nn import MaxPool2d

from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import QuantTensor

from .mixin.base import QuantLayerMixin
Expand Down Expand Up @@ -81,8 +82,7 @@ def requires_export_handler(self):
def forward(self, input: Union[Tensor, QuantTensor]):
x = self.unpack_input(input)
if self.export_mode:
x_value = getattr(x, 'value', x)
out = self.export_handler(x_value)
out = self.export_handler(_unpack_quant_tensor(x))
self._set_global_is_quant_layer(False)
return out
x = x.set(value=super().forward(x.value))
Expand Down
4 changes: 0 additions & 4 deletions src/brevitas/quant_tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@
BFLOAT16_IS_VALID_ATOL = 0.5


def _get_dequantize_tensor(input):
return input.value if isinstance(input, QuantTensor) else input


class QuantTensorBase(NamedTuple):
value: Tensor
scale: Optional[Tensor]
Expand Down
4 changes: 3 additions & 1 deletion tests/brevitas/nn/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,7 @@ def test_forward_bias_int(self):
torch.rand(size=(3, INPUT_FEATURES)),
torch.tensor(1.0),
torch.tensor(0.0),
torch.tensor(3))
torch.tensor(3),
signed=True,
training=False)
assert mod(x) is not None
6 changes: 3 additions & 3 deletions tests/brevitas/nn/test_nn_quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,12 @@ def test_quant_lstm_rnn_full(model_input, current_cases):
assert isinstance(h, torch.Tensor)

if c is not None:
if not return_quant_tensor:
assert isinstance(c, torch.Tensor)
else:
if return_quant_tensor:
assert isinstance(c, QuantTensor)
assert c.scale is not None
assert c.bit_width is not None
else:
assert isinstance(c, torch.Tensor)


@pytest_cases.parametrize_with_cases('model_input', cases=[case_quant_lstm, case_quant_rnn])
Expand Down

0 comments on commit 52d2b59

Please sign in to comment.