renaming parallel_layers to len_parallel_layers
fabianandresgrob committed Nov 13, 2023
1 parent 57bcc67 commit b04c680
Showing 3 changed files with 14 additions and 13 deletions.
10 changes: 5 additions & 5 deletions src/brevitas/graph/gpfq.py
@@ -95,12 +95,12 @@ def catch_stopfwd(self, *args, **kwargs):
         return out
 
     def initialize_module_optimizer(
-            self, layer, name, act_order, parallel_layers, create_weight_orig):
+            self, layer, name, act_order, len_parallel_layers, create_weight_orig):
         return GPFQ(
             layer=layer,
             name=name,
             act_order=act_order,
-            parallel_layers=parallel_layers,
+            len_parallel_layers=len_parallel_layers,
             create_weight_orig=create_weight_orig,
             p=self.p)

@@ -115,14 +115,14 @@ def __init__(
             layer,
             name,
             act_order,
-            parallel_layers=1,
+            len_parallel_layers=1,
             create_weight_orig=True,
             p=0.25) -> None:
 
         if act_order:
             raise ValueError("Act_order is not supported in GPFQ")
 
-        super().__init__(layer, name, act_order, parallel_layers, create_weight_orig)
+        super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig)
         self.float_input = None
         self.quantized_input = None
         self.index_computed = False
@@ -203,7 +203,7 @@ def update_batch(self, module, input, current_layer):
         # we executed. Once we executed as many as the number of parallel_layers, we raise
         # StopFwdException
         current_layer.forward_count += 1
-        if current_layer.forward_count == self.parallel_layers:
+        if current_layer.forward_count == self.len_parallel_layers:
             current_layer.forward_count = 0
             raise StopFwdException

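The update_batch hunk above is the consumer of the renamed attribute: each layer's pre-forward hook accumulates statistics and bumps a shared counter, and once as many hooks have fired as there are layers sharing that input, the forward pass is cut short. A minimal sketch of that counting scheme follows; only len_parallel_layers, forward_count and StopFwdException come from the diff, the surrounding class names are illustrative rather than the actual Brevitas internals.

class StopFwdException(Exception):
    """Raised to abort the forward pass once every parallel layer has run."""


class CurrentLayer:
    """Shared state tracking how many of the parallel layers have executed."""

    def __init__(self):
        self.forward_count = 0


class ModuleOptimizer:

    def __init__(self, len_parallel_layers=1):
        # A count of the sibling layers fed by the same input, not the layers themselves.
        self.len_parallel_layers = len_parallel_layers

    def update_batch(self, module, inp, current_layer):
        # ... accumulate per-layer statistics from inp here ...
        current_layer.forward_count += 1
        if current_layer.forward_count == self.len_parallel_layers:
            current_layer.forward_count = 0
            raise StopFwdException
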
10 changes: 5 additions & 5 deletions src/brevitas/graph/gptq.py
@@ -84,12 +84,12 @@ def catch_stopfwd(self, *args, **kwargs):
         return out
 
     def initialize_module_optimizer(
-            self, layer, name, act_order, parallel_layers, create_weight_orig):
+            self, layer, name, act_order, len_parallel_layers, create_weight_orig):
        return GPTQ(
             layer=layer,
             name=name,
             act_order=act_order,
-            parallel_layers=parallel_layers,
+            len_parallel_layers=len_parallel_layers,
             create_weight_orig=create_weight_orig,
             num_blocks=self.num_blocks)

@@ -118,10 +118,10 @@ def __init__(
             layer,
             name,
             act_order,
-            parallel_layers=1,
+            len_parallel_layers=1,
             create_weight_orig=True,
             num_blocks=100) -> None:
-        super().__init__(layer, name, act_order, parallel_layers, create_weight_orig)
+        super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig)
 
         dev = self.layer.weight.device
 
@@ -184,7 +184,7 @@ def update_batch(self, module, input, current_layer):
         # we executed. Once we executed as many as the number of parallel_layers, we raise
         # StopFwdException
         current_layer.forward_count += 1
-        if current_layer.forward_count == self.parallel_layers:
+        if current_layer.forward_count == self.len_parallel_layers:
             current_layer.forward_count = 0
             raise StopFwdException

7 changes: 4 additions & 3 deletions src/brevitas/graph/gpxq.py
@@ -102,7 +102,7 @@ def __enter__(self):
                     module,
                     name,
                     act_order=self.act_order,
-                    parallel_layers=len(parallel_layers),
+                    len_parallel_layers=len(parallel_layers),
                     create_weight_orig=self.create_weight_orig)
                 hook_fn = partial(
                     gpxq_module_optimizer.update_batch, current_layer=self.current_layer)
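
The hunk above is where the rename pays off: __enter__ passes len(parallel_layers), i.e. the number of layers that share the same input, so the keyword now says explicitly that an integer count is expected rather than the list of layers itself. A hedged sketch of how this wiring might look; register_forward_pre_hook is the standard torch hook API, and the helper and grouping names here are illustrative, not the exact Brevitas call sites.

from functools import partial


def register_gpxq_hooks(parallel_groups, make_optimizer, current_layer):
    # parallel_groups: iterable of lists of (name, module) pairs, where every
    # module in one list consumes the same input tensor.
    hooks = []
    for parallel_layers in parallel_groups:
        for name, module in parallel_layers:
            optimizer = make_optimizer(
                module,
                name,
                len_parallel_layers=len(parallel_layers))  # a count, hence the new name
            hook_fn = partial(optimizer.update_batch, current_layer=current_layer)
            hooks.append(module.register_forward_pre_hook(hook_fn))
    return hooks
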
@@ -138,7 +138,8 @@ def catch_stopfwd(self, *args, **kwargs):
 
 class GPxQ(ABC):
 
-    def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None:
+    def __init__(
+            self, layer, name, act_order, len_parallel_layers=1, create_weight_orig=True) -> None:
         self.layer = layer
         self.name = name
         self.act_order = act_order
@@ -160,7 +161,7 @@ def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig
         self.rows = weight.shape[0]
         # Number of columns is equal to the input channels (IC)
         self.columns = weight.shape[1]
-        self.parallel_layers = parallel_layers
+        self.len_parallel_layers = len_parallel_layers
 
         self.disable_pre_forward_hook = False
         # Some layers require knowledge from quant inputs to compute quant weights

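For context on where this code runs: the per-module optimizers touched by this commit are driven by the gptq_mode / gpfq_mode context managers defined in the same files. A hedged usage sketch, modelled on the Brevitas PTQ examples; argument names and defaults may differ between versions.

import torch

from brevitas.graph.gptq import gptq_mode


def apply_gptq(model, calibration_loader):
    model.eval()
    with torch.no_grad():
        with gptq_mode(model) as gptq:
            gptq_model = gptq.model
            # One outer pass per layer being quantized; each inner pass feeds the
            # calibration data through the model so the hooks can collect statistics.
            for _ in range(gptq.num_layers):
                for images, _ in calibration_loader:
                    gptq_model(images)
                gptq.update()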