From af1eef439627db3fd7c0ca2480d275e8ab1a3f8e Mon Sep 17 00:00:00 2001 From: i-colbert Date: Wed, 4 Dec 2024 18:25:50 +0000 Subject: [PATCH] Fix (gpxq): general unpacking of quant tensor --- src/brevitas/graph/gpfq.py | 1 - src/brevitas/graph/gpxq.py | 9 ++++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index b16faaf6b..17fe24b74 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -218,7 +218,6 @@ def update_batch(self, module, input, current_layer): is_quant_enabled = module.weight_quant.is_quant_enabled inp = self.process_input(input) - inp = _unpack_quant_tensor(inp) # Preprocess the input to compute the Hessian if isinstance(self.layer, qnn.QuantLinear): diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index 2af0d7f2d..df3446614 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -20,8 +20,7 @@ from brevitas.graph.calibrate import restore_return_quant_tensor from brevitas.graph.utils import is_conv_transposed import brevitas.nn as qnn -from brevitas.quant_tensor import IntQuantTensor -from brevitas.utils.quant_utils import _CachedIO +from brevitas.quant_tensor import QuantTensor SUPPORTED_TCONV_OP = (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d) @@ -224,11 +223,11 @@ def process_input(self, inp): is_quant_enabled = self.layer.weight_quant.is_quant_enabled - # If using quantized activations, inp could be IntQuantTensor. In + # If using quantized activations, inp could be QuantTensor. In # this case, we overwrite the metadata. - if isinstance(inp, IntQuantTensor): + if isinstance(inp, QuantTensor): if is_quant_enabled and self.quant_metadata is None: - self.quant_metadata = _CachedIO(inp, metadata_only=True) + self.quant_metadata = self.layer.input_quant.cache_class(inp, metadata_only=True) inp = inp.value # If input is unbatched, add batch_size = 1