Update for weight orig
Giuseppe5 committed Sep 27, 2023
1 parent f043fcd commit dec588f
Showing 3 changed files with 20 additions and 7 deletions.
6 changes: 4 additions & 2 deletions src/brevitas/graph/gpfq.py
@@ -41,6 +41,7 @@ def __init__(
             model,
             group_of_parallel_layers: Optional[List[str]] = None,
             inplace: bool = True,
+            create_weight_orig: bool = True,
             use_quant_activations: bool = True,
             p: int = 0.25,
             return_forward_output: bool = False,
@@ -51,6 +52,7 @@ def __init__(
             model,
             group_of_parallel_layers,
             inplace,
+            create_weight_orig,
             use_quant_activations,
             act_order,
             return_forward_output)
@@ -100,12 +102,12 @@ class GPFQ(GPxQ):
     """
     p = 0.25
 
-    def __init__(self, layer, name, act_order, parallel_layers=1) -> None:
+    def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None:
 
         if act_order:
             raise ValueError("Act_order is not supported in GPFQ")
 
-        super().__init__(layer, name, act_order, parallel_layers)
+        super().__init__(layer, name, act_order, parallel_layers, create_weight_orig)
         self.float_input = None
         self.quantized_input = None
         self.index_computed = False
6 changes: 4 additions & 2 deletions src/brevitas/graph/gptq.py
@@ -47,6 +47,7 @@ def __init__(
             model,
             group_of_parallel_layers: Optional[List[str]] = None,
             inplace: bool = True,
+            create_weight_orig: bool = True,
             use_quant_activations: bool = True,
             num_blocks: int = 100,
             return_forward_output: bool = False,
@@ -57,6 +58,7 @@ def __init__(
             model,
             group_of_parallel_layers,
             inplace,
+            create_weight_orig,
             use_quant_activations,
             act_order,
             return_forward_output)
@@ -104,8 +106,8 @@ class GPTQ(GPxQ):
     """
     num_blocks = 100
 
-    def __init__(self, layer, name, act_order, parallel_layers=1) -> None:
-        super().__init__(layer, name, act_order, parallel_layers)
+    def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None:
+        super().__init__(layer, name, act_order, parallel_layers, create_weight_orig)
 
         dev = self.layer.weight.device
 
15 changes: 12 additions & 3 deletions src/brevitas/graph/gpxq.py
@@ -36,13 +36,15 @@ def __init__(
             model,
             group_of_parallel_layers: Optional[List[str]] = None,
             inplace: bool = True,
+            create_weight_orig: bool = True,
             use_quant_activations: bool = True,
             act_order: bool = False,
             return_forward_output: bool = False) -> None:
 
         if not inplace:
             model = deepcopy(model)
         self.model = model
+        self.create_weight_orig = create_weight_orig
         self.use_quant_activations = use_quant_activations
         self.hook_dict = dict()
         self.gpxq_layers = dict()
@@ -97,7 +99,11 @@ def __enter__(self):
                 # Attach hooks for GPTQ
                 if self._is_module_supported(module):
                     gpxq = self.class_implementation(
-                        module, name, act_order=self.act_order, parallel_layers=parallel_layers)
+                        module,
+                        name,
+                        act_order=self.act_order,
+                        parallel_layers=parallel_layers,
+                        create_weight_orig=self.create_weight_orig)
                     hook_fn = partial(gpxq.update_batch, current_layer=self.current_layer)
                     self.hook_dict[name] = module.register_forward_pre_hook(hook_fn)
                     self.gpxq_layers[name] = gpxq
@@ -131,13 +137,16 @@ def catch_stopfwd(self, *args, **kwargs):
 
 class GPxQ(ABC):
 
-    def __init__(self, layer, name, act_order, parallel_layers=1) -> None:
+    def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None:
         self.layer = layer
         self.name = name
         self.act_order = act_order
 
         weight = layer.weight.data
-        self.layer.weight_orig = deepcopy(layer.weight)
+
+        if create_weight_orig and not hasattr(self.layer, 'weight_orig'):
+            self.layer.register_buffer('weight_orig', layer.weight.detach().clone())
+
         # By default, use groups = 1
         self.groups = 1
         if isinstance(self.layer, SUPPORTED_CONV_OP):
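To make the behavioral change in gpxq.py concrete: instead of unconditionally attaching a deep copy of the weight Parameter to the layer, the original weights are now stored in a buffer named weight_orig, and only when create_weight_orig is set and the buffer does not already exist. Below is a minimal, standalone sketch of that pattern on a plain nn.Linear; the module and variable names are placeholders for illustration, not Brevitas code.

import torch
import torch.nn as nn

# Stand-in for a layer handed to GPxQ; illustration only.
layer = nn.Linear(8, 4)

create_weight_orig = True
if create_weight_orig and not hasattr(layer, 'weight_orig'):
    # A buffer is saved in the state_dict and follows the module across devices,
    # but is not returned by parameters(); assigning a deepcopied nn.Parameter
    # (the previous behavior) would have registered an extra trainable parameter.
    layer.register_buffer('weight_orig', layer.weight.detach().clone())

# The snapshot stays fixed while layer.weight is later modified in place.
with torch.no_grad():
    layer.weight.add_(1.0)
print(torch.equal(layer.weight_orig, layer.weight))  # False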

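For usage context, here is a sketch of how a caller might pass the new flag once this change lands. It follows the typical gptq_mode calibration loop from the Brevitas examples, but the toy model, the synthetic data, and the attribute names used here (gptq.model, gptq.num_layers, gptq.update()) should be treated as assumptions rather than part of this commit.

import torch
import torch.nn as nn
from brevitas.graph.gptq import gptq_mode
from brevitas.nn import QuantLinear

# Toy quantized model and synthetic calibration batches (placeholders).
model = nn.Sequential(QuantLinear(16, 8, bias=True), nn.ReLU(), QuantLinear(8, 4, bias=True))
calib_data = [torch.randn(32, 16) for _ in range(4)]

model.eval()
with torch.no_grad():
    # create_weight_orig=False skips registering the 'weight_orig' buffer on each
    # layer, saving memory when the original float weights are not needed later.
    with gptq_mode(model, create_weight_orig=False, use_quant_activations=False) as gptq:
        gptq_model = gptq.model
        for _ in range(gptq.num_layers):
            for batch in calib_data:
                gptq_model(batch)
            gptq.update()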