Feat (ptq): adding accumulator-aware extensions to GPxQ #1060

Merged · 19 commits · Oct 26, 2024
239 changes: 124 additions & 115 deletions src/brevitas/graph/gpfq.py
@@ -31,110 +31,6 @@
from brevitas.quant_tensor import _unpack_quant_tensor


class gpfq_mode(gpxq_mode):
"""
Apply GPFQ algorithm.

Args:
model (Module): The model to quantize with GPFQ
group_of_parallel_layers (Optional, List[str]): List of lists where each inner list is a group
of layer names that can be optimized in parallel. Default: None
inplace (bool): Whether to apply GPFQ inplace or perform a deepcopy. Default: True
create_weight_orig (bool): If True, store the original floating point weights before applying
gpfq. These weights will be used anytime quantization is disabled. Default: True
use_quant_activations (bool): Whether to leave quantized activations enabled while performing
GPFQ. Default: True
p (float): The percentage of processed inputs to use. Default: 1.0
return_forward_output (bool): If True, returns the output of the forward pass. Otherwise the
forward call inside the context manager returns None. Default: False
act_order (bool): Whether to order greedy path following by Hessian approximation. Default: False

Example:
>>> with torch.no_grad():
>>> with gpfq_mode(model) as gpfq:
>>> gpfq_model = gpfq.model
>>> for i in tqdm(range(gpfq.num_layers)):
>>> for img, t in calib_loader:
>>> img = img.cuda()
>>> gpfq_model(img)
>>> gpfq.update()
"""

def __init__(
self,
model: nn.Module,
group_of_parallel_layers: Optional[List[str]] = None,
inplace: bool = True,
create_weight_orig: bool = True,
use_quant_activations: bool = True,
p: float = 1.0,
return_forward_output: bool = False,
act_order: bool = False,
gpfq_class: Optional[nn.Module] = None) -> None:
if not inplace:
model = deepcopy(model)
super().__init__(
model,
group_of_parallel_layers,
inplace,
create_weight_orig,
use_quant_activations,
act_order,
return_forward_output)

self.p = p
if gpfq_class is None:
gpfq_class = GPFQ
self.gpfq_class = gpfq_class
assert issubclass(gpfq_class, GPxQ), \
"Error: expected `gpfq_class` to be derived from `brevitas.graph.gpxq.GPxQ`."

def catch_stopfwd(self, *args, **kwargs):
# Collect quant input
try:
self.orig_forward(*args, **kwargs)
except StopFwdException:
pass

# Disable quantization
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
self.disable_quant_inference.disable_param_quantization(self.model, is_training=False)
self.disable_quant_inference.disable_act_quantization(self.model, is_training=False)
# Collect float input
try:
self.orig_forward(*args, **kwargs)
except StopFwdException:
pass

# Re-enable quantization. If activation quantization is disabled,
# we also disable bias quantization
self.disable_quant_inference.enable_param_quantization(self.model, is_training=False)
if self.use_quant_activations:
self.disable_quant_inference.enable_act_quantization(self.model, is_training=False)
else:
self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False)
restore_return_quant_tensor(self.model, self.return_quant_tensor_state)

if self.return_forward_output:
# If we want to return the output of the network, we need to disable all hooks
for name, gpxq_class in self.gpxq_layers.items():
gpxq_class.disable_pre_forward_hook = True
out = self.orig_forward(*args, **kwargs)
for name, gpxq_class in self.gpxq_layers.items():
gpxq_class.disable_pre_forward_hook = False
return out

def initialize_module_optimizer(
self, layer, name, act_order, len_parallel_layers, create_weight_orig):
return self.gpfq_class(
layer=layer,
name=name,
act_order=act_order,
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=self.p)


class GPFQ(GPxQ):
"""
Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main
@@ -243,7 +139,11 @@ def single_layer_update(self):
self.float_input = self.float_input.to(dev)
self.quant_input = self.quant_input.to(dev)
U = torch.zeros(
weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype)
weight.shape[0],
weight.shape[1],
self.float_input.shape[1],
device=dev,
dtype=torch.float32)
# We don't need full Hessian, we just need the diagonal
# Summing over batch dimension
H_diag = self.quant_input.transpose(2, 1).square().sum(2)
@@ -261,7 +161,8 @@
for t in range(weight.shape[-1]):
for group_index in range(self.groups):
U[group_index] += torch.matmul(
weight[group_index, :, permutation_list[group_index][t]].unsqueeze(1),
weight[group_index, :,
permutation_list[group_index][t]].unsqueeze(1).to(torch.float32),
self.float_input[group_index, :, permutation_list[group_index][t]].unsqueeze(
0)) #[OC/Groups, 1] * [1, INSHAPE[1]]
norm = torch.linalg.norm(
@@ -272,11 +173,11 @@
else:
q_arg = torch.zeros_like(U[group_index, :, 0])

weight[group_index, :, permutation_list[group_index][t]] = q_arg
weight[group_index, :, permutation_list[group_index][t]] = q_arg.to(dtype)
q = self.get_quant_weights(t, 0, permutation_list)
for group_index in range(self.groups):
U[group_index] -= torch.matmul(
q[group_index].unsqueeze(1),
q[group_index].unsqueeze(1).to(torch.float32),
self.quant_input[group_index, :, permutation_list[group_index][t]].unsqueeze(0))

del self.float_input
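
As a reading aid (not part of the diff): per group g and permutation step t, the loop above implements the greedy GPFQ update

U_g <- U_g + w_{g,t} x_{g,t}^T                  # accumulate the float-path contribution
a_{g,t} = (U_g \tilde{x}_{g,t}) / norm_{g,t}    # least-squares argument; norm_{g,t} is the quantized-column norm computed above
q_{g,t} = quant(a_{g,t})                        # weight column after writing a_{g,t} back and re-quantizing
U_g <- U_g - q_{g,t} \tilde{x}_{g,t}^T          # remove the quantized-path contribution

where x_{g,t} and \tilde{x}_{g,t} are the float and quantized input columns selected by permutation_list. With this patch, U and both outer products are accumulated in float32, and a_{g,t} is cast back to the layer dtype only when written into the weight tensor.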
@@ -360,7 +261,7 @@ def update_batch(self, module, input, current_layer):
# if quant is not enabled, then it is the float input; if it is a float input
# then a quant input has already happened and we can update G
if not is_quant_enabled:
# Computing the normalized H matrix using CPU buffer
# Computing the normalized G matrix using CPU buffer
self.B.copy_(self.quant_input.bmm(inp_processed.transpose(2, 1)))
self.G += self.B
self.quant_input = None # NOTE: set back to None now that we've used it
@@ -401,6 +302,8 @@ def _get_permutation_list(self, weight: Tensor):
def single_layer_update(self, percdamp: float = 0.01):
assert not self.layer.weight_quant.requires_quant_input, \
"Error: GPFQ does not support weight quantizers that require quantized inputs."
if hasattr(self.layer, "allocate_params"):
self.layer.allocate_params(self.layer)
weight = self.layer.weight.data
dev = weight.device
dtype = weight.dtype
@@ -446,28 +349,134 @@ def single_layer_update(self, percdamp: float = 0.01):
permutation_list = self._get_permutation_list(weight)

U = torch.zeros(
weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev,
dtype=dtype) # [Groups, OC/groups, Samples]
weight.shape[0],
weight.shape[1],
self.float_input.shape[1],
device=dev,
dtype=torch.float32) # [Groups, OC/groups, Samples]

for t in range(weight.shape[-1]):
for group_index in range(self.groups):
i = permutation_list[group_index][t]
U[group_index] += torch.matmul(
weight[group_index, :, i].unsqueeze(1),
weight[group_index, :, i].unsqueeze(1).to(torch.float32),
self.float_input[group_index, :, i].unsqueeze(0),
) # [OC/Groups, 1] * [1, INSHAPE[1]]
norm = norms[group_index, i]
if norm > 0:
q_arg = U[group_index].matmul(self.quant_input[group_index, :, i]) / norm
else:
q_arg = torch.zeros_like(U[group_index, :, 0])
weight[group_index, :, i] = q_arg
weight[group_index, :, i] = q_arg.to(dtype)
q_groups = self.get_quant_weights(t, 0, permutation_list)
for group_index in range(self.groups):
U[group_index] -= torch.matmul(
q_groups[group_index].unsqueeze(1),
q_groups[group_index].unsqueeze(1).to(torch.float32),
self.quant_input[group_index, :, permutation_list[group_index][t]].unsqueeze(0),
)

if hasattr(self.layer, 'offload_params'):
self.layer.offload_params(self.layer)
del self.float_input
del self.quant_input


class gpfq_mode(gpxq_mode):
"""
Apply GPFQ algorithm.

Args:
model (Module): The model to quantize with GPFQ
group_of_parallel_layers (Optional, List[str]): List of lists where each inner list is a group
of layer names that can be optimized in parallel. Default: None
inplace (bool): Whether to apply GPFQ inplace or perform a deepcopy. Default: True
create_weight_orig (bool): If True, store the original floating point weights before applying
gpfq. These weights will be used anytime quantization is disabled. Default: True
use_quant_activations (bool): Whether to leave quantized activations enabled while performing
GPFQ. Default: True
p (float): The percentage of processed inputs to use. Default: 1.0
return_forward_output (bool): If True, returns the output of the forward pass. Otherwise the
forward call inside the context manager returns None. Default: False
act_order (bool): Whether to order greedy path following by Hessian approximation. Default: False
gpfq_class (GPFQ): The uninstantiated GPFQ class used to optimize each layer. Default:
`brevitas.graph.gpfq.GPFQv2`, the memory-efficient formulation.

Example:
>>> with torch.no_grad():
>>> with gpfq_mode(model) as gpfq:
>>> gpfq_model = gpfq.model
>>> for i in tqdm(range(gpfq.num_layers)):
>>> for img, t in calib_loader:
>>> img = img.cuda()
>>> gpfq_model(img)
>>> gpfq.update()
"""

def __init__(
self,
model: nn.Module,
group_of_parallel_layers: Optional[List[str]] = None,
inplace: bool = True,
create_weight_orig: bool = True,
use_quant_activations: bool = True,
p: float = 1.0,
return_forward_output: bool = False,
act_order: bool = False,
gpfq_class: GPFQ = GPFQv2) -> None:
if not inplace:
model = deepcopy(model)
super().__init__(
model,
group_of_parallel_layers,
inplace,
create_weight_orig,
use_quant_activations,
act_order,
return_forward_output)

self.p = p
self.gpfq_class = gpfq_class

def catch_stopfwd(self, *args, **kwargs):
# Collect quant input
try:
self.orig_forward(*args, **kwargs)
except StopFwdException:
pass

# Disable quantization
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
self.disable_quant_inference.disable_param_quantization(self.model, is_training=False)
self.disable_quant_inference.disable_act_quantization(self.model, is_training=False)
# Collect float input
try:
self.orig_forward(*args, **kwargs)
except StopFwdException:
pass

# Re-enable quantization. If activation quantization is disabled,
# we also disable bias quantization
self.disable_quant_inference.enable_param_quantization(self.model, is_training=False)
if self.use_quant_activations:
self.disable_quant_inference.enable_act_quantization(self.model, is_training=False)
else:
self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False)
restore_return_quant_tensor(self.model, self.return_quant_tensor_state)

if self.return_forward_output:
# If we want to return the output of the network, we need to disable all hooks
for name, gpxq_class in self.gpxq_layers.items():
gpxq_class.disable_pre_forward_hook = True
out = self.orig_forward(*args, **kwargs)
for name, gpxq_class in self.gpxq_layers.items():
gpxq_class.disable_pre_forward_hook = False
return out

def initialize_module_optimizer(
self, layer, name, act_order, len_parallel_layers, create_weight_orig):
return self.gpfq_class(
layer=layer,
name=name,
act_order=act_order,
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=self.p)
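
For context, a minimal calibration sketch using the relocated gpfq_mode with the new gpfq_class argument; this mirrors the docstring example above, assumes model and calib_loader are defined by the caller, and simply passes the GPFQv2 default explicitly:

import torch
from tqdm import tqdm

from brevitas.graph.gpfq import GPFQv2, gpfq_mode

model = model.cuda().eval()  # assumed: a quantized Brevitas model prepared by the caller
with torch.no_grad():
    # gpfq_class defaults to the memory-efficient GPFQv2; pass brevitas.graph.gpfq.GPFQ
    # to fall back to the original formulation instead.
    with gpfq_mode(model, p=1.0, use_quant_activations=True, gpfq_class=GPFQv2) as gpfq:
        gpfq_model = gpfq.model
        for _ in tqdm(range(gpfq.num_layers)):
            for img, _ in calib_loader:
                gpfq_model(img.cuda())
            gpfq.update()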