diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py
index a04f579bc..cf23d18b0 100644
--- a/src/brevitas/graph/gpfq.py
+++ b/src/brevitas/graph/gpfq.py
@@ -95,12 +95,12 @@ def catch_stopfwd(self, *args, **kwargs):
         return out
 
     def initialize_module_optimizer(
-            self, layer, name, act_order, parallel_layers, create_weight_orig):
+            self, layer, name, act_order, len_parallel_layers, create_weight_orig):
         return GPFQ(
             layer=layer,
             name=name,
             act_order=act_order,
-            parallel_layers=parallel_layers,
+            len_parallel_layers=len_parallel_layers,
             create_weight_orig=create_weight_orig,
             p=self.p)
 
@@ -115,14 +115,14 @@ def __init__(
             layer,
             name,
             act_order,
-            parallel_layers=1,
+            len_parallel_layers=1,
             create_weight_orig=True,
             p=0.25) -> None:
 
         if act_order:
             raise ValueError("Act_order is not supported in GPFQ")
 
-        super().__init__(layer, name, act_order, parallel_layers, create_weight_orig)
+        super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig)
         self.float_input = None
         self.quantized_input = None
         self.index_computed = False
@@ -203,7 +203,7 @@ def update_batch(self, module, input, current_layer):
         # we executed. Once we executed as many as the number of parallel_layers, we raise
         # StopFwdException
         current_layer.forward_count += 1
-        if current_layer.forward_count == self.parallel_layers:
+        if current_layer.forward_count == self.len_parallel_layers:
             current_layer.forward_count = 0
             raise StopFwdException
 
diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py
index 653d60ccd..b10943f1b 100644
--- a/src/brevitas/graph/gptq.py
+++ b/src/brevitas/graph/gptq.py
@@ -84,12 +84,12 @@ def catch_stopfwd(self, *args, **kwargs):
         return out
 
     def initialize_module_optimizer(
-            self, layer, name, act_order, parallel_layers, create_weight_orig):
+            self, layer, name, act_order, len_parallel_layers, create_weight_orig):
         return GPTQ(
             layer=layer,
             name=name,
             act_order=act_order,
-            parallel_layers=parallel_layers,
+            len_parallel_layers=len_parallel_layers,
             create_weight_orig=create_weight_orig,
             num_blocks=self.num_blocks)
 
@@ -118,10 +118,10 @@ def __init__(
             layer,
             name,
             act_order,
-            parallel_layers=1,
+            len_parallel_layers=1,
             create_weight_orig=True,
             num_blocks=100) -> None:
-        super().__init__(layer, name, act_order, parallel_layers, create_weight_orig)
+        super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig)
 
         dev = self.layer.weight.device
 
@@ -184,7 +184,7 @@ def update_batch(self, module, input, current_layer):
         # we executed. Once we executed as many as the number of parallel_layers, we raise
         # StopFwdException
         current_layer.forward_count += 1
-        if current_layer.forward_count == self.parallel_layers:
+        if current_layer.forward_count == self.len_parallel_layers:
             current_layer.forward_count = 0
             raise StopFwdException
 
diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py
index 7d7bb030a..1279950a8 100644
--- a/src/brevitas/graph/gpxq.py
+++ b/src/brevitas/graph/gpxq.py
@@ -102,7 +102,7 @@ def __enter__(self):
                     module,
                     name,
                     act_order=self.act_order,
-                    parallel_layers=len(parallel_layers),
+                    len_parallel_layers=len(parallel_layers),
                     create_weight_orig=self.create_weight_orig)
                 hook_fn = partial(
                     gpxq_module_optimizer.update_batch, current_layer=self.current_layer)
@@ -138,7 +138,8 @@ def catch_stopfwd(self, *args, **kwargs):
 
 class GPxQ(ABC):
 
-    def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None:
+    def __init__(
+            self, layer, name, act_order, len_parallel_layers=1, create_weight_orig=True) -> None:
         self.layer = layer
         self.name = name
         self.act_order = act_order
@@ -160,7 +161,7 @@ def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig
         self.rows = weight.shape[0]
         # Number of columns is equal to the input channels (IC)
         self.columns = weight.shape[1]
-        self.parallel_layers = parallel_layers
+        self.len_parallel_layers = len_parallel_layers
         self.disable_pre_forward_hook = False
 
         # Some layers require knowledge from quant inputs to compute quant weights
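
For context on the rename: `len_parallel_layers` stores the *count* of layers optimized in parallel (it is set from `len(parallel_layers)` in `gpxq.py`), and `update_batch` compares a shared forward counter against it to decide when to abort the calibration forward pass. Below is a minimal, self-contained sketch of that reset-and-raise counting mechanism; the `Counter` class, module shapes, and hook wiring are illustrative assumptions, not brevitas's actual implementation (where siblings consume the same input rather than running in sequence).

```python
from functools import partial

import torch
import torch.nn as nn


class StopFwdException(Exception):
    """Raised to cut the calibration forward pass short, as in gpxq.py."""


class Counter:
    # Hypothetical stand-in for the shared `current_layer` state object.
    def __init__(self):
        self.forward_count = 0


def update_batch(module, inp, current_layer, len_parallel_layers):
    # Sketch of the counting logic in GPxQ.update_batch: each hooked layer
    # bumps the shared counter when its input arrives; once as many hooks
    # have fired as there are layers handled in parallel, the counter resets
    # and the remainder of the forward pass is skipped.
    current_layer.forward_count += 1
    if current_layer.forward_count == len_parallel_layers:
        current_layer.forward_count = 0
        raise StopFwdException


layers = [nn.Linear(8, 8), nn.Linear(8, 8)]
model = nn.Sequential(*layers, nn.Linear(8, 4))
counter = Counter()
hooks = [
    layer.register_forward_pre_hook(
        partial(update_batch, current_layer=counter, len_parallel_layers=len(layers)))
    for layer in layers]

try:
    model(torch.randn(1, 8))  # the hooks observe each tracked layer's input...
except StopFwdException:
    pass  # ...and everything after the last tracked layer never executes
finally:
    for h in hooks:
        h.remove()
```

This also explains why the attribute now stores a length rather than the layer list itself: `update_batch` only ever needs the number of parallel siblings, never the list.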