diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index 2840b0578..ef720d092 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -17,7 +17,6 @@ from brevitas.graph.gpxq import StopFwdException from brevitas.graph.gpxq import SUPPORTED_CONV_OP import brevitas.nn as qnn -from brevitas.quant_tensor import _unpack_quant_tensor class gpfq_mode(gpxq_mode): @@ -163,7 +162,6 @@ def update_batch(self, module, input, current_layer): is_quant_enabled = module.weight_quant.is_quant_enabled inp = self.process_input(input) - inp = _unpack_quant_tensor(inp) batch_size = inp.shape[0] # Preprocess the input to compute the Hessian diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index 6594dda94..31d31433b 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -21,7 +21,6 @@ from brevitas.graph.gpxq import StopFwdException from brevitas.graph.gpxq import SUPPORTED_CONV_OP import brevitas.nn as qnn -from brevitas.quant_tensor import _unpack_quant_tensor class gptq_mode(gpxq_mode): @@ -145,7 +144,6 @@ def update_batch(self, module, input, current_layer): # Update reference to current layer current_layer.layer_names.add(self.name) inp = self.process_input(input) - inp = _unpack_quant_tensor(inp) batch_size = inp.shape[0] # Preprocess the input to compute the Hessian