Feat (gpfq): parameter allocation/offloading
i-colbert committed Oct 18, 2024
1 parent 87ccb88 commit 9c57bea
Showing 1 changed file with 5 additions and 4 deletions.

src/brevitas/graph/gpfq.py (5 additions, 4 deletions)

@@ -70,7 +70,7 @@ def __init__(
             p: float = 1.0,
             return_forward_output: bool = False,
             act_order: bool = False,
-            gpfq_class: Optional[nn.Module] = None) -> None:
+            gpfq_class: Optional[GPxQ] = None) -> None:
         if not inplace:
             model = deepcopy(model)
         super().__init__(
@@ -86,8 +86,6 @@ def __init__(
         if gpfq_class is None:
             gpfq_class = GPFQ
         self.gpfq_class = gpfq_class
-        assert issubclass(gpfq_class, GPxQ), \
-            "Error: expected `gpfq_class` to be derived from `brevitas.graph.gpxq.GPxQ`."
 
     def catch_stopfwd(self, *args, **kwargs):
         # Collect quant input
@@ -401,6 +399,8 @@ def _get_permutation_list(self, weight: Tensor):
     def single_layer_update(self, percdamp: float = 0.01):
         assert not self.layer.weight_quant.requires_quant_input, \
             "Error: GPFQ does not support weight quantizers that require quantized inputs."
+        if hasattr(self.layer, "allocate_params"):
+            self.layer.allocate_params(self.layer)
         weight = self.layer.weight.data
         dev = weight.device
         dtype = weight.dtype
Expand Down Expand Up @@ -468,6 +468,7 @@ def single_layer_update(self, percdamp: float = 0.01):
q_groups[group_index].unsqueeze(1),
self.quant_input[group_index, :, permutation_list[group_index][t]].unsqueeze(0),
)

if hasattr(self.layer, 'offload_params'):
self.layer.offload_params(self.layer)
del self.float_input
del self.quant_input

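The first two hunks are a related cleanup: gpfq_class is now annotated as Optional[GPxQ] rather than Optional[nn.Module], so the hint names the actual base class, and the redundant runtime issubclass assertion is dropped. Here is a sketch of how a caller might supply a custom class, assuming the usual gpfq_mode calibration loop from the Brevitas examples; MyGPFQ is hypothetical, and model / calib_loader are placeholders for a prepared model and a calibration DataLoader.

    from brevitas.graph.gpfq import GPFQ, gpfq_mode

    class MyGPFQ(GPFQ):
        """Hypothetical GPFQ variant; hooks can be overridden as needed."""

        def single_layer_update(self, *args, **kwargs):
            # Custom pre/post logic could wrap the stock update here.
            return super().single_layer_update(*args, **kwargs)

    with gpfq_mode(model, gpfq_class=MyGPFQ) as gpfq:
        gpfq_model = gpfq.model
        for _ in range(gpfq.num_layers):
            for images in calib_loader:
                gpfq_model(images)
            gpfq.update()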