diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py
index 454f73b3e..2d4fa97ad 100644
--- a/src/brevitas/nn/mixin/base.py
+++ b/src/brevitas/nn/mixin/base.py
@@ -172,8 +172,9 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]):
                 cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
                 self._cached_inp = cached_inp
         # Remove any naming metadata to avoid dowmstream errors
+        # Avoid inplace operations on the input in case of forward hooks
         if not torch._C._get_tracing_state():
-            inp.value.rename_(None)
+            inp = inp.set(value=inp.value.rename(None))
         return inp

     def pack_output(self, quant_output: QuantTensor):
diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py
index f78d39b3b..00f8e472e 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py
@@ -36,6 +36,7 @@ from brevitas.inject.enum import FloatToIntImplType
 from brevitas.inject.enum import LearnedRoundImplType
 from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
+from brevitas.quant_tensor import QuantTensor

 config.IGNORE_MISSING_KEYS = True

@@ -53,6 +54,19 @@ def __init__(self, store_output: False):
         self.output_store = None

     def __call__(self, module, input_batch, output_batch):
+        input_batch = input_batch[0]
+        if isinstance(input_batch, QuantTensor):
+            input_batch = input_batch.value
+
+        if hasattr(input_batch, 'names') and 'N' in input_batch.names:
+            batch_dim = input_batch.names.index('N')
+
+            input_batch.rename_(None)
+            input_batch = input_batch.transpose(0, batch_dim)
+            if self.store_output:
+                output_batch.rename_(None)
+                output_batch = output_batch.transpose(0, batch_dim)
+
         if self.store_output:
             self.output_store = output_batch
         self.input_store = input_batch
@@ -183,9 +197,9 @@ def save_inp_out_data(
                 pass
         if store_inp:
             if keep_gpu:
-                cached[0].append(data_saver.input_store[0].detach())
+                cached[0].append(data_saver.input_store.detach())
             else:
-                cached[0].append(data_saver.input_store[0].detach().cpu())
+                cached[0].append(data_saver.input_store.detach().cpu())
         if store_out:
             if keep_gpu:
                 cached[1].append(data_saver.output_store.detach())