Skip to content

Commit

Permalink
Feat (graph/gpfq): compression with random projection (#964)
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianandresgrob authored May 31, 2024
1 parent 8c71e08 commit 0f60606
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 12 deletions.
55 changes: 47 additions & 8 deletions src/brevitas/graph/gpfq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
# SPDX-License-Identifier: BSD-3-Clause

from copy import deepcopy
import math
from math import pi
from typing import Callable, List, Optional

import numpy as np
import torch
from torch.fft import fft
from torch.fft import fftn
import torch.nn as nn
import unfoldNd

Expand All @@ -19,6 +23,24 @@
import brevitas.nn as qnn


def random_projection(
float_input: torch.Tensor, quantized_input: torch.Tensor, compression_rate: float):
# use random projection to reduce dimensionality
n = quantized_input.size(1)
target_dim = int(compression_rate * n)
dev = float_input.device
# create gaussian random matrix
R = torch.normal(mean=0.0, std=1. / math.sqrt(n), size=(target_dim, n), device=dev)
quantized_input = torch.transpose(quantized_input, 1, 2) @ R.T
float_input = torch.transpose(float_input, 1, 2) @ R.T
del R
# reshape back
quantized_input = torch.transpose(quantized_input, 1, 2)
float_input = torch.transpose(float_input, 1, 2)

return float_input, quantized_input


class gpfq_mode(gpxq_mode):
"""
Apply GPFQ algorithm.
Expand Down Expand Up @@ -64,7 +86,8 @@ def __init__(
act_order: bool = False,
use_gpfa2q: bool = False,
accumulator_bit_width: Optional[int] = None,
a2q_layer_filter_fnc: Optional[Callable[[nn.Module], bool]] = lambda x: True) -> None:
a2q_layer_filter_fnc: Optional[Callable[[nn.Module], bool]] = lambda x: True,
compression_rate: Optional[float] = 0.0) -> None:
if not inplace:
model = deepcopy(model)
super().__init__(
Expand All @@ -83,6 +106,11 @@ def __init__(
self.accumulator_bit_width = accumulator_bit_width
self.a2q_layer_filter_fnc = a2q_layer_filter_fnc # returns true when to use GPFA2Q

# selecting impl of random proj
self.compression_rate = compression_rate
if self.compression_rate < 0.0 or self.compression_rate > 1.0:
raise ValueError('Compression rate for random projection must be between 0 and 1.')

def catch_stopfwd(self, *args, **kwargs):
# Collect quant input
try:
Expand Down Expand Up @@ -127,7 +155,8 @@ def initialize_module_optimizer(
act_order=act_order,
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=self.p)
p=self.p,
compression_rate=self.compression_rate)
else:
return GPFA2Q(
layer=layer,
Expand All @@ -136,22 +165,26 @@ def initialize_module_optimizer(
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=self.p,
accumulator_bit_width=self.accumulator_bit_width)
accumulator_bit_width=self.accumulator_bit_width,
compression_rate=self.compression_rate)


class GPFQ(GPxQ):
"""
Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main
"""

def __init__(self, layer, name, act_order, len_parallel_layers, create_weight_orig, p) -> None:
def __init__(
self, layer, name, act_order, len_parallel_layers, create_weight_orig, p,
compression_rate) -> None:

super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig)

self.float_input = None
self.quantized_input = None
self.index_computed = False
self.p = p
self.compression_rate = compression_rate

def update_batch(self, module, input, current_layer):
if self.disable_pre_forward_hook:
Expand Down Expand Up @@ -246,10 +279,12 @@ def single_layer_update(self):
weight = weight.transpose(1, 0) # This performs a view
weight = weight.flatten(1)
weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC]
U = torch.zeros(
weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype)
if self.compression_rate > 0.0:
self.float_input, self.quantized_input = random_projection(self.float_input, self.quantized_input, self.compression_rate)
self.float_input = self.float_input.to(dev)
self.quantized_input = self.quantized_input.to(dev)
U = torch.zeros(
weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype)
# We don't need full Hessian, we just need the diagonal
self.H_diag = self.quantized_input.transpose(2, 1).square().sum(
2) # summing over Batch dimension
Expand Down Expand Up @@ -300,15 +335,17 @@ def __init__(
len_parallel_layers,
create_weight_orig,
accumulator_bit_width,
p) -> None:
p,
compression_rate) -> None:
GPFQ.__init__(
self,
layer=layer,
name=name,
act_order=act_order,
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=p)
p=p,
compression_rate=compression_rate)
self.accumulator_bit_width = accumulator_bit_width
assert self.accumulator_bit_width is not None

Expand All @@ -329,6 +366,8 @@ def single_layer_update(self):
weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC]
U = torch.zeros(
weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype)
if self.compression_rate > 0.0:
self.float_input, self.quantized_input = random_projection(self.float_input, self.quantized_input, self.compression_rate)
self.float_input = self.float_input.to(dev)
self.quantized_input = self.quantized_input.to(dev)

Expand Down
12 changes: 10 additions & 2 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,14 @@ def apply_gptq(calib_loader, model, act_order=False):
gptq.update()


def apply_gpfq(calib_loader, model, act_order, p=1.0, use_gpfa2q=False, accumulator_bit_width=None):
def apply_gpfq(
calib_loader,
model,
act_order,
p=1.0,
use_gpfa2q=False,
accumulator_bit_width=None,
compression_rate=0.0):
model.eval()
dtype = next(model.parameters()).dtype
device = next(model.parameters()).device
Expand All @@ -545,7 +552,8 @@ def apply_gpfq(calib_loader, model, act_order, p=1.0, use_gpfa2q=False, accumula
use_quant_activations=True,
act_order=act_order,
use_gpfa2q=use_gpfa2q,
accumulator_bit_width=accumulator_bit_width) as gpfq:
accumulator_bit_width=accumulator_bit_width,
compression_rate=compression_rate) as gpfq:
gpfq_model = gpfq.model
for i in tqdm(range(gpfq.num_layers)):
for i, (images, target) in enumerate(calib_loader):
Expand Down
16 changes: 14 additions & 2 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,12 @@ def parse_type(v, default_type):
help=
'Split Ratio for Channel Splitting. When set to 0.0, Channel Splitting will not be applied. (default: 0.0)'
)
parser.add_argument(
'--compression-rate',
default=0.0,
type=float,
help='Specify compression rate < 1.0 for random projection. Default is 0.0 and does not use RP.'
)
add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)')
add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)')
add_bool_arg(parser, 'gpfa2q', default=False, help='GPFA2Q (default: disabled)')
Expand Down Expand Up @@ -426,7 +432,12 @@ def main():

if args.gpfq:
print("Performing GPFQ:")
apply_gpfq(calib_loader, quant_model, p=args.gpfq_p, act_order=args.gpxq_act_order)
apply_gpfq(
calib_loader,
quant_model,
p=args.gpfq_p,
act_order=args.gpxq_act_order,
compression_rate=args.compression_rate)

if args.gpfa2q:
print("Performing GPFA2Q:")
Expand All @@ -436,7 +447,8 @@ def main():
p=args.gpfq_p,
act_order=args.gpxq_act_order,
use_gpfa2q=args.gpfa2q,
accumulator_bit_width=args.accumulator_bit_width)
accumulator_bit_width=args.accumulator_bit_width,
compression_rate=args.compression_rate)

if args.gptq:
print("Performing GPTQ:")
Expand Down

0 comments on commit 0f60606

Please sign in to comment.