diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py
index 01dc11a82..cf23d18b0 100644
--- a/src/brevitas/graph/gpfq.py
+++ b/src/brevitas/graph/gpfq.py
@@ -59,8 +59,7 @@ def __init__(
 
         self.orig_forward = self.model.forward
         self.model.forward = self.catch_stopfwd
-        self.class_implementation = GPFQ
-        GPFQ.p = p
+        self.p = p
 
     def catch_stopfwd(self, *args, **kwargs):
         # Collect quant input
@@ -95,23 +94,39 @@ def catch_stopfwd(self, *args, **kwargs):
             gpxq_class.disable_pre_forward_hook = False
         return out
 
+    def initialize_module_optimizer(
+            self, layer, name, act_order, len_parallel_layers, create_weight_orig):
+        return GPFQ(
+            layer=layer,
+            name=name,
+            act_order=act_order,
+            len_parallel_layers=len_parallel_layers,
+            create_weight_orig=create_weight_orig,
+            p=self.p)
+
 
 class GPFQ(GPxQ):
     """
     Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main
     """
-    p = 0.25
 
-    def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None:
+    def __init__(
+            self,
+            layer,
+            name,
+            act_order,
+            len_parallel_layers=1,
+            create_weight_orig=True,
+            p=0.25) -> None:
 
         if act_order:
             raise ValueError("Act_order is not supported in GPFQ")
 
-        super().__init__(layer, name, act_order, parallel_layers, create_weight_orig)
+        super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig)
         self.float_input = None
         self.quantized_input = None
         self.index_computed = False
-        self.p = GPFQ.p
+        self.p = p
 
     def update_batch(self, module, input, current_layer):
         if self.disable_pre_forward_hook:
@@ -188,7 +203,7 @@ def update_batch(self, module, input, current_layer):
         # we executed. Once we executed as many as the number of parallel_layers, we raise
         # StopFwdException
         current_layer.forward_count += 1
-        if current_layer.forward_count == len(self.parallel_layers):
+        if current_layer.forward_count == self.len_parallel_layers:
             current_layer.forward_count = 0
             raise StopFwdException
 
diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py
index b224e0a37..b10943f1b 100644
--- a/src/brevitas/graph/gptq.py
+++ b/src/brevitas/graph/gptq.py
@@ -67,8 +67,6 @@ def __init__(
         self.model.forward = self.catch_stopfwd
         # How many subblock to use during GPTQ for each layer
         self.num_blocks = num_blocks
-        self.class_implementation = GPTQ
-        GPTQ.num_blocks = num_blocks
 
     def catch_stopfwd(self, *args, **kwargs):
         try:
@@ -85,6 +83,16 @@ def catch_stopfwd(self, *args, **kwargs):
             gpxq_class.disable_pre_forward_hook = False
         return out
 
+    def initialize_module_optimizer(
+            self, layer, name, act_order, len_parallel_layers, create_weight_orig):
+        return GPTQ(
+            layer=layer,
+            name=name,
+            act_order=act_order,
+            len_parallel_layers=len_parallel_layers,
+            create_weight_orig=create_weight_orig,
+            num_blocks=self.num_blocks)
+
 
 class GPTQ(GPxQ):
     """
@@ -104,15 +112,21 @@ class GPTQ(GPxQ):
     See the License for the specific language governing permissions and
     limitations under the License.
     """
-    num_blocks = 100
 
-    def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None:
-        super().__init__(layer, name, act_order, parallel_layers, create_weight_orig)
+    def __init__(
+            self,
+            layer,
+            name,
+            act_order,
+            len_parallel_layers=1,
+            create_weight_orig=True,
+            num_blocks=100) -> None:
+        super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig)
         dev = self.layer.weight.device
 
         # Define how many columns to update in each mini-block
-        self.blocksize = math.ceil(self.columns / GPTQ.num_blocks)
+        self.blocksize = math.ceil(self.columns / num_blocks)
 
         # Initialize Hessian matrix and counter. We need it in float32 to compute the inverse
         self.H = torch.zeros((self.groups, self.columns, self.columns),
@@ -170,7 +184,7 @@ def update_batch(self, module, input, current_layer):
         # we executed. Once we executed as many as the number of parallel_layers, we raise
         # StopFwdException
         current_layer.forward_count += 1
-        if current_layer.forward_count == len(self.parallel_layers):
+        if current_layer.forward_count == self.len_parallel_layers:
             current_layer.forward_count = 0
             raise StopFwdException
 
diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py
index b13c46683..1279950a8 100644
--- a/src/brevitas/graph/gpxq.py
+++ b/src/brevitas/graph/gpxq.py
@@ -98,15 +98,16 @@ def __enter__(self):
 
                 # Attach hooks for GPTQ
                 if self._is_module_supported(module):
-                    gpxq = self.class_implementation(
+                    gpxq_module_optimizer = self.initialize_module_optimizer(
                         module,
                         name,
                         act_order=self.act_order,
-                        parallel_layers=parallel_layers,
+                        len_parallel_layers=len(parallel_layers),
                         create_weight_orig=self.create_weight_orig)
-                    hook_fn = partial(gpxq.update_batch, current_layer=self.current_layer)
+                    hook_fn = partial(
+                        gpxq_module_optimizer.update_batch, current_layer=self.current_layer)
                     self.hook_dict[name] = module.register_forward_pre_hook(hook_fn)
-                    self.gpxq_layers[name] = gpxq
+                    self.gpxq_layers[name] = gpxq_module_optimizer
         if not self.use_quant_activations:
             self.disable_quant_inference.disable_act_quantization(
                 self.model, is_training=self.model.training)
@@ -137,7 +138,8 @@ def catch_stopfwd(self, *args, **kwargs):
 
 class GPxQ(ABC):
 
-    def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None:
+    def __init__(
+            self, layer, name, act_order, len_parallel_layers=1, create_weight_orig=True) -> None:
         self.layer = layer
         self.name = name
         self.act_order = act_order
@@ -159,7 +161,7 @@ def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig
         self.rows = weight.shape[0]
         # Number of columns is equal to the input channels (IC)
         self.columns = weight.shape[1]
-        self.parallel_layers = parallel_layers
+        self.len_parallel_layers = len_parallel_layers
        self.disable_pre_forward_hook = False
 
         # Some layers require knowledge from quant inputs to compute quant weights