From 4e966cbf0ff505892aa0e5170203861f025d2e6d Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 3 Oct 2023 17:37:29 +0100 Subject: [PATCH] Revert name dropping --- src/brevitas/nn/mixin/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index 42c1d4911..2d4fa97ad 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -172,8 +172,9 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]): cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only) self._cached_inp = cached_inp # Remove any naming metadata to avoid dowmstream errors - # if not torch._C._get_tracing_state(): - # inp = QuantTensor(inp.value.rename(None), inp.scale, inp.zero_point, inp.bit_width, inp.signed, inp.training) + # Avoid inplace operations on the input in case of forward hooks + if not torch._C._get_tracing_state(): + inp = inp.set(value=inp.value.rename(None)) return inp def pack_output(self, quant_output: QuantTensor):