Feat (ptq): adding accumulator-aware extensions to GPxQ (#1060)
i-colbert authored Oct 26, 2024
1 parent 06af14b commit 2cb8c9d
Showing 8 changed files with 812 additions and 257 deletions.
239 changes: 124 additions & 115 deletions src/brevitas/graph/gpfq.py
@@ -31,110 +31,6 @@
from brevitas.quant_tensor import _unpack_quant_tensor


class gpfq_mode(gpxq_mode):
"""
Apply GPFQ algorithm.
Args:
model (Module): The model to quantize with GPFQ
group_of_parallel_layers (Optional, List[str]): List of lists where each inner list is a group
of layer names that can be optimized in parallel. Default: None
inplace (bool): Whether to apply GPFQ inplace or perform a deepcopy. Default: True
create_weight_orig (bool): If True, store the original floating point weights before applying
GPFQ. These weights will be used anytime quantization is disabled. Default: True
use_quant_activations (bool): Whether to leave quantized activations enabled while performing
GPFQ. Default: True
p (float): The percentage of processed inputs to use. Default: 1.0
return_forward_output (bool): If True, returns the output of the forward pass. Otherwise the
forward call inside the context manager returns None. Default: False
act_order (bool): Whether to order greedy path following by Hessian approximation. Default: False
Example:
>>> with torch.no_grad():
>>> with gpfq_mode(model) as gpfq:
>>> gpfq_model = gpfq.model
>>> for i in tqdm(range(gpfq.num_layers)):
>>> for img, t in calib_loader:
>>> img = img.cuda()
>>> gpfq_model(img)
>>> gpfq.update()
"""

def __init__(
self,
model: nn.Module,
group_of_parallel_layers: Optional[List[str]] = None,
inplace: bool = True,
create_weight_orig: bool = True,
use_quant_activations: bool = True,
p: float = 1.0,
return_forward_output: bool = False,
act_order: bool = False,
gpfq_class: Optional[nn.Module] = None) -> None:
if not inplace:
model = deepcopy(model)
super().__init__(
model,
group_of_parallel_layers,
inplace,
create_weight_orig,
use_quant_activations,
act_order,
return_forward_output)

self.p = p
if gpfq_class is None:
gpfq_class = GPFQ
self.gpfq_class = gpfq_class
assert issubclass(gpfq_class, GPxQ), \
"Error: expected `gpfq_class` to be derived from `brevitas.graph.gpxq.GPxQ`."

def catch_stopfwd(self, *args, **kwargs):
# Collect quant input
try:
self.orig_forward(*args, **kwargs)
except StopFwdException:
pass

# Disable quantization
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
self.disable_quant_inference.disable_param_quantization(self.model, is_training=False)
self.disable_quant_inference.disable_act_quantization(self.model, is_training=False)
# Collect float input
try:
self.orig_forward(*args, **kwargs)
except StopFwdException:
pass

# Re-enable quantization. If activation quantization is disabled,
# we also disable bias quantization
self.disable_quant_inference.enable_param_quantization(self.model, is_training=False)
if self.use_quant_activations:
self.disable_quant_inference.enable_act_quantization(self.model, is_training=False)
else:
self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False)
restore_return_quant_tensor(self.model, self.return_quant_tensor_state)

if self.return_forward_output:
# If we want to return the output of the network, we need to disable all hooks
for name, gpxq_class in self.gpxq_layers.items():
gpxq_class.disable_pre_forward_hook = True
out = self.orig_forward(*args, **kwargs)
for name, gpxq_class in self.gpxq_layers.items():
gpxq_class.disable_pre_forward_hook = False
return out

def initialize_module_optimizer(
self, layer, name, act_order, len_parallel_layers, create_weight_orig):
return self.gpfq_class(
layer=layer,
name=name,
act_order=act_order,
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=self.p)


class GPFQ(GPxQ):
"""
Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main
@@ -243,7 +139,11 @@ def single_layer_update(self):
self.float_input = self.float_input.to(dev)
self.quant_input = self.quant_input.to(dev)
U = torch.zeros(
weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype)
weight.shape[0],
weight.shape[1],
self.float_input.shape[1],
device=dev,
dtype=torch.float32)
# We don't need full Hessian, we just need the diagonal
# Summing over batch dimension
H_diag = self.quant_input.transpose(2, 1).square().sum(2)
@@ -261,7 +161,8 @@ def single_layer_update(self):
for t in range(weight.shape[-1]):
for group_index in range(self.groups):
U[group_index] += torch.matmul(
weight[group_index, :, permutation_list[group_index][t]].unsqueeze(1),
weight[group_index, :,
permutation_list[group_index][t]].unsqueeze(1).to(torch.float32),
self.float_input[group_index, :, permutation_list[group_index][t]].unsqueeze(
0)) #[OC/Groups, 1] * [1, INSHAPE[1]]
norm = torch.linalg.norm(
@@ -272,11 +173,11 @@ def single_layer_update(self):
else:
q_arg = torch.zeros_like(U[group_index, :, 0])

weight[group_index, :, permutation_list[group_index][t]] = q_arg
weight[group_index, :, permutation_list[group_index][t]] = q_arg.to(dtype)
q = self.get_quant_weights(t, 0, permutation_list)
for group_index in range(self.groups):
U[group_index] -= torch.matmul(
q[group_index].unsqueeze(1),
q[group_index].unsqueeze(1).to(torch.float32),
self.quant_input[group_index, :, permutation_list[group_index][t]].unsqueeze(0))

del self.float_input
@@ -360,7 +261,7 @@ def update_batch(self, module, input, current_layer):
# if quant is not enabled, then it is the float input; if it is a float input
# then a quant input has already happened and we can update G
if not is_quant_enabled:
# Computing the normalized H matrix using CPU buffer
# Computing the normalized G matrix using CPU buffer
self.B.copy_(self.quant_input.bmm(inp_processed.transpose(2, 1)))
self.G += self.B
self.quant_input = None # NOTE: set back to None now that we've used it
@@ -401,6 +302,8 @@ def _get_permutation_list(self, weight: Tensor):
def single_layer_update(self, percdamp: float = 0.01):
assert not self.layer.weight_quant.requires_quant_input, \
"Error: GPFQ does not support weight quantizers that require quantized inputs."
if hasattr(self.layer, "allocate_params"):
self.layer.allocate_params(self.layer)
weight = self.layer.weight.data
dev = weight.device
dtype = weight.dtype
@@ -446,28 +349,134 @@ def single_layer_update(self, percdamp: float = 0.01):
permutation_list = self._get_permutation_list(weight)

U = torch.zeros(
weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev,
dtype=dtype) # [Groups, OC/groups, Samples]
weight.shape[0],
weight.shape[1],
self.float_input.shape[1],
device=dev,
dtype=torch.float32) # [Groups, OC/groups, Samples]

for t in range(weight.shape[-1]):
for group_index in range(self.groups):
i = permutation_list[group_index][t]
U[group_index] += torch.matmul(
weight[group_index, :, i].unsqueeze(1),
weight[group_index, :, i].unsqueeze(1).to(torch.float32),
self.float_input[group_index, :, i].unsqueeze(0),
) # [OC/Groups, 1] * [1, INSHAPE[1]]
norm = norms[group_index, i]
if norm > 0:
q_arg = U[group_index].matmul(self.quant_input[group_index, :, i]) / norm
else:
q_arg = torch.zeros_like(U[group_index, :, 0])
weight[group_index, :, i] = q_arg
weight[group_index, :, i] = q_arg.to(dtype)
q_groups = self.get_quant_weights(t, 0, permutation_list)
for group_index in range(self.groups):
U[group_index] -= torch.matmul(
q_groups[group_index].unsqueeze(1),
q_groups[group_index].unsqueeze(1).to(torch.float32),
self.quant_input[group_index, :, permutation_list[group_index][t]].unsqueeze(0),
)

if hasattr(self.layer, 'offload_params'):
self.layer.offload_params(self.layer)
del self.float_input
del self.quant_input


class gpfq_mode(gpxq_mode):
"""
Apply GPFQ algorithm.
Args:
model (Module): The model to quantize with GPFQ
group_of_parallel_layers (Optional, List[str]): List of lists where each inner list is a group
of layer names that can be optimized in parallel. Default: None
inplace (bool): Whether to apply GPFQ inplace or perform a deepcopy. Default: True
create_weight_orig (bool): If True, store the original floating point weights before applying
GPFQ. These weights will be used anytime quantization is disabled. Default: True
use_quant_activations (bool): Whether to leave quantized activations enabled while performing
GPFQ. Default: True
p (float): The percentage of processed inputs to use. Default: 1.0
return_forward_output (bool): If True, returns the output of the forward pass. Otherwise the
forward call inside the context manager returns None. Default: False
act_order (bool): Whether to order greedy path following by Hessian approximation. Default: False
gpfq_class (GPFQ): The uninitialized class to perform GPFQ. Default: `brevitas.graph.gpfq.GPFQv2`,
which is the memory-efficient formulation
Example:
>>> with torch.no_grad():
>>> with gpfq_mode(model) as gpfq:
>>> gpfq_model = gpfq.model
>>> for i in tqdm(range(gpfq.num_layers)):
>>> for img, t in calib_loader:
>>> img = img.cuda()
>>> gpfq_model(img)
>>> gpfq.update()
"""

def __init__(
self,
model: nn.Module,
group_of_parallel_layers: Optional[List[str]] = None,
inplace: bool = True,
create_weight_orig: bool = True,
use_quant_activations: bool = True,
p: float = 1.0,
return_forward_output: bool = False,
act_order: bool = False,
gpfq_class: GPFQ = GPFQv2) -> None:
if not inplace:
model = deepcopy(model)
super().__init__(
model,
group_of_parallel_layers,
inplace,
create_weight_orig,
use_quant_activations,
act_order,
return_forward_output)

self.p = p
self.gpfq_class = gpfq_class

def catch_stopfwd(self, *args, **kwargs):
# Collect quant input
try:
self.orig_forward(*args, **kwargs)
except StopFwdException:
pass

# Disable quantization
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
self.disable_quant_inference.disable_param_quantization(self.model, is_training=False)
self.disable_quant_inference.disable_act_quantization(self.model, is_training=False)
# Collect float input
try:
self.orig_forward(*args, **kwargs)
except StopFwdException:
pass

# Re-enable quantization. If activation quantization is disabled,
# we also disable bias quantization
self.disable_quant_inference.enable_param_quantization(self.model, is_training=False)
if self.use_quant_activations:
self.disable_quant_inference.enable_act_quantization(self.model, is_training=False)
else:
self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False)
restore_return_quant_tensor(self.model, self.return_quant_tensor_state)

if self.return_forward_output:
# If we want to return the output of the network, we need to disable all hooks
for name, gpxq_class in self.gpxq_layers.items():
gpxq_class.disable_pre_forward_hook = True
out = self.orig_forward(*args, **kwargs)
for name, gpxq_class in self.gpxq_layers.items():
gpxq_class.disable_pre_forward_hook = False
return out

def initialize_module_optimizer(
self, layer, name, act_order, len_parallel_layers, create_weight_orig):
return self.gpfq_class(
layer=layer,
name=name,
act_order=act_order,
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=self.p)
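
For context, below is a minimal usage sketch of the new `gpfq_class` hook, adapted from the docstring example above. The `model`, `calib_loader`, and CUDA device placement are placeholders for illustration, not part of this commit; the accumulator-aware optimizers added elsewhere in this commit would be plugged in through the same argument.

import torch
from tqdm import tqdm

from brevitas.graph.gpfq import GPFQv2, gpfq_mode

# `model` is assumed to be an already-quantized Brevitas model and `calib_loader`
# a calibration DataLoader; both are placeholders for this sketch.
with torch.no_grad():
    # gpfq_class selects the per-layer optimizer: GPFQv2 (the default,
    # memory-efficient formulation) or the original GPFQ class.
    with gpfq_mode(model, use_quant_activations=True, gpfq_class=GPFQv2) as gpfq:
        gpfq_model = gpfq.model
        for _ in tqdm(range(gpfq.num_layers)):
            for img, _ in calib_loader:
                img = img.cuda()
                gpfq_model(img)
            gpfq.update()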