From 6d386d6102c33b887138c8d35e8f3f97c2c2dd59 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 3 Oct 2023 17:12:35 +0100 Subject: [PATCH] Fix (learned_round): use of named tensor --- src/brevitas/nn/mixin/base.py | 4 ++-- .../ptq/learned_round_utils.py | 15 +++++++++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index 454f73b3e..42c1d4911 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -172,8 +172,8 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]): cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only) self._cached_inp = cached_inp # Remove any naming metadata to avoid dowmstream errors - if not torch._C._get_tracing_state(): - inp.value.rename_(None) + # if not torch._C._get_tracing_state(): + # inp = QuantTensor(inp.value.rename(None), inp.scale, inp.zero_point, inp.bit_width, inp.signed, inp.training) return inp def pack_output(self, quant_output: QuantTensor): diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py index f78d39b3b..598cdf314 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py +++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py @@ -53,6 +53,17 @@ def __init__(self, store_output: False): self.output_store = None def __call__(self, module, input_batch, output_batch): + input_batch = input_batch[0] + + if hasattr(input_batch, 'names') and 'N' in input_batch.names: + batch_dim = input_batch.names.index('N') + + input_batch.rename_(None) + input_batch = input_batch.transpose(0, batch_dim) + if self.store_output: + output_batch.rename_(None) + output_batch = output_batch.transpose(0, batch_dim) + if self.store_output: self.output_store = output_batch self.input_store = input_batch @@ -183,9 +194,9 @@ def save_inp_out_data( pass if store_inp: if keep_gpu: - cached[0].append(data_saver.input_store[0].detach()) + cached[0].append(data_saver.input_store.detach()) else: - cached[0].append(data_saver.input_store[0].detach().cpu()) + cached[0].append(data_saver.input_store.detach().cpu()) if store_out: if keep_gpu: cached[1].append(data_saver.output_store.detach())