Fix (GPFQ): using random projection for speed up/less memory usage #964

Merged
merged 2 commits on May 31, 2024
Changes from 1 commit
55 changes: 47 additions & 8 deletions src/brevitas/graph/gpfq.py
@@ -2,10 +2,14 @@
# SPDX-License-Identifier: BSD-3-Clause

from copy import deepcopy
import math
from math import pi
from typing import Callable, List, Optional

import numpy as np
import torch
from torch.fft import fft
from torch.fft import fftn
import torch.nn as nn
import unfoldNd

@@ -19,6 +23,24 @@
import brevitas.nn as qnn


def random_projection(
float_input: torch.Tensor, quantized_input: torch.Tensor, compression_rate: float):
# use random projection to reduce dimensionality
n = quantized_input.size(1)
target_dim = int(compression_rate * n)
dev = float_input.device
# create gaussian random matrix
R = torch.normal(mean=0.0, std=1. / math.sqrt(n), size=(target_dim, n), device=dev)
quantized_input = torch.transpose(quantized_input, 1, 2) @ R.T
float_input = torch.transpose(float_input, 1, 2) @ R.T
del R
# reshape back
quantized_input = torch.transpose(quantized_input, 1, 2)
float_input = torch.transpose(float_input, 1, 2)

return float_input, quantized_input
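
For intuition, this helper is a Johnson-Lindenstrauss-style sketch: a Gaussian matrix with entries drawn from N(0, 1/n) maps the sample dimension from n down to target_dim = int(compression_rate * n) while approximately preserving inner products between the float and quantized activations. A minimal standalone mirror of the idea (shapes and names below are illustrative, not part of this diff):

import math
import torch

def rp_sketch(x_float, x_quant, compression_rate):
    # illustrative layout: [groups, n_samples, features]
    n = x_quant.size(1)
    target_dim = int(compression_rate * n)
    # Gaussian sketching matrix, same distribution as in random_projection above
    R = torch.normal(mean=0.0, std=1.0 / math.sqrt(n), size=(target_dim, n))
    # project the sample dimension: [g, f, n] @ [n, t] -> [g, f, t], then restore the layout
    x_quant = (x_quant.transpose(1, 2) @ R.T).transpose(1, 2)
    x_float = (x_float.transpose(1, 2) @ R.T).transpose(1, 2)
    return x_float, x_quant

x_f, x_q = rp_sketch(torch.randn(1, 4096, 64), torch.randn(1, 4096, 64), 0.25)
print(x_f.shape, x_q.shape)  # both become [1, 1024, 64]: 4x fewer samples to process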


class gpfq_mode(gpxq_mode):
"""
Apply GPFQ algorithm.
@@ -64,7 +86,8 @@ def __init__(
act_order: bool = False,
use_gpfa2q: bool = False,
accumulator_bit_width: Optional[int] = None,
a2q_layer_filter_fnc: Optional[Callable[[nn.Module], bool]] = lambda x: True) -> None:
a2q_layer_filter_fnc: Optional[Callable[[nn.Module], bool]] = lambda x: True,
compression_rate: Optional[float] = 0.0) -> None:
if not inplace:
model = deepcopy(model)
super().__init__(
@@ -83,6 +106,11 @@ def __init__(
self.accumulator_bit_width = accumulator_bit_width
self.a2q_layer_filter_fnc = a2q_layer_filter_fnc # returns true when to use GPFA2Q

# selecting impl of random proj
self.compression_rate = compression_rate
if self.compression_rate < 0.0 or self.compression_rate > 1.0:
raise ValueError(f'Compression rate for random projection cannot be {compression_rate}')
fabianandresgrob marked this conversation as resolved.

def catch_stopfwd(self, *args, **kwargs):
# Collect quant input
try:
Expand Down Expand Up @@ -127,7 +155,8 @@ def initialize_module_optimizer(
act_order=act_order,
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=self.p)
p=self.p,
compression_rate=self.compression_rate)
else:
return GPFA2Q(
layer=layer,
@@ -136,22 +165,26 @@
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=self.p,
accumulator_bit_width=self.accumulator_bit_width)
accumulator_bit_width=self.accumulator_bit_width,
compression_rate=self.compression_rate)


class GPFQ(GPxQ):
"""
Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main
"""

def __init__(self, layer, name, act_order, len_parallel_layers, create_weight_orig, p) -> None:
def __init__(
self, layer, name, act_order, len_parallel_layers, create_weight_orig, p,
compression_rate) -> None:

super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig)

self.float_input = None
self.quantized_input = None
self.index_computed = False
self.p = p
self.compression_rate = compression_rate

def update_batch(self, module, input, current_layer):
if self.disable_pre_forward_hook:
Expand Down Expand Up @@ -246,10 +279,12 @@ def single_layer_update(self):
weight = weight.transpose(1, 0) # This performs a view
weight = weight.flatten(1)
weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC]
U = torch.zeros(
weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype)
if self.compression_rate > 0.0:
self.float_input, self.quantized_input = random_projection(self.float_input, self.quantized_input, self.compression_rate)
self.float_input = self.float_input.to(dev)
self.quantized_input = self.quantized_input.to(dev)
U = torch.zeros(
weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype)
# We don't need full Hessian, we just need the diagonal
self.H_diag = self.quantized_input.transpose(2, 1).square().sum(
2) # summing over Batch dimension
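
Note the reordering in this hunk: U is now allocated after the optional projection, which is why torch.zeros appears twice in the diff above. Its last dimension is float_input.shape[1], i.e. the (possibly projected) sample dimension, and H_diag reduces over that same dimension, so any compression_rate below 1.0 shrinks both directly. As a rough, assumed illustration: with float_input of shape [1, 4096, 512] and compression_rate=0.25, that dimension drops from 4096 to int(0.25 * 4096) = 1024, a 4x reduction in the U buffer.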
@@ -300,15 +335,17 @@ def __init__(
len_parallel_layers,
create_weight_orig,
accumulator_bit_width,
p) -> None:
p,
compression_rate) -> None:
GPFQ.__init__(
self,
layer=layer,
name=name,
act_order=act_order,
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=p)
p=p,
compression_rate=compression_rate)
self.accumulator_bit_width = accumulator_bit_width
assert self.accumulator_bit_width is not None

@@ -329,6 +366,8 @@ def single_layer_update(self):
weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC]
U = torch.zeros(
weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype)
if self.compression_rate > 0.0:
self.float_input, self.quantized_input = random_projection(self.float_input, self.quantized_input, self.compression_rate)
self.float_input = self.float_input.to(dev)
self.quantized_input = self.quantized_input.to(dev)
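
Taken together, enabling the projection from user code only requires the new compression_rate argument on gpfq_mode. A minimal sketch, assuming a calibration loop shaped like the apply_gpfq driver further down and a gpfq.update() call that mirrors gptq.update(); model and calib_loader are placeholders:

from brevitas.graph.gpfq import gpfq_mode

with gpfq_mode(model,
               use_quant_activations=True,
               act_order=False,
               compression_rate=0.5) as gpfq:  # keep 50% of the collected samples
    gpfq_model = gpfq.model
    for _ in range(gpfq.num_layers):
        for images, _ in calib_loader:
            gpfq_model(images)  # forward passes collect float and quantized activations
        gpfq.update()  # run the per-layer GPFQ update on the (optionally projected) inputs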

12 changes: 10 additions & 2 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -535,7 +535,14 @@ def apply_gptq(calib_loader, model, act_order=False):
gptq.update()


def apply_gpfq(calib_loader, model, act_order, p=1.0, use_gpfa2q=False, accumulator_bit_width=None):
def apply_gpfq(
calib_loader,
model,
act_order,
p=1.0,
use_gpfa2q=False,
accumulator_bit_width=None,
compression_rate=0.0):
model.eval()
dtype = next(model.parameters()).dtype
device = next(model.parameters()).device
@@ -545,7 +552,8 @@ def apply_gpfq(calib_loader, model, act_order, p=1.0, use_gpfa2q=False, accumula
use_quant_activations=True,
act_order=act_order,
use_gpfa2q=use_gpfa2q,
accumulator_bit_width=accumulator_bit_width) as gpfq:
accumulator_bit_width=accumulator_bit_width,
compression_rate=compression_rate) as gpfq:
gpfq_model = gpfq.model
for i in tqdm(range(gpfq.num_layers)):
for i, (images, target) in enumerate(calib_loader):
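
For a direct call, the new argument is forwarded like any other keyword (a sketch; quant_model and calib_loader are placeholders for the user's quantized model and calibration DataLoader):

# project the collected activations down to 25% of their original sample count
apply_gpfq(calib_loader, quant_model, act_order=False, p=1.0, compression_rate=0.25)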
16 changes: 14 additions & 2 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
@@ -239,6 +239,12 @@ def parse_type(v, default_type):
help=
'Split Ratio for Channel Splitting. When set to 0.0, Channel Splitting will not be applied. (default: 0.0)'
)
parser.add_argument(
'--compression-rate',
default=0.0,
type=float,
help='Specify compression rate < 1.0 for random projection. Default is 0.0 and does not use RP.'
)
add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)')
add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)')
add_bool_arg(parser, 'gpfa2q', default=False, help='GPFA2Q (default: disabled)')
@@ -426,7 +432,12 @@ def main():

if args.gpfq:
print("Performing GPFQ:")
apply_gpfq(calib_loader, quant_model, p=args.gpfq_p, act_order=args.gpxq_act_order)
apply_gpfq(
calib_loader,
quant_model,
p=args.gpfq_p,
act_order=args.gpxq_act_order,
compression_rate=args.compression_rate)

if args.gpfa2q:
print("Performing GPFA2Q:")
@@ -436,7 +447,8 @@ def main():
p=args.gpfq_p,
act_order=args.gpxq_act_order,
use_gpfa2q=args.gpfa2q,
accumulator_bit_width=args.accumulator_bit_width)
accumulator_bit_width=args.accumulator_bit_width,
compression_rate=args.compression_rate)

if args.gptq:
print("Performing GPTQ:")