Feat (gpfq): add a different way of doing random projection
fabianandresgrob committed May 6, 2024
1 parent 977de54 commit 25e8735
Showing 1 changed file with 75 additions and 11 deletions.
86 changes: 75 additions & 11 deletions src/brevitas/graph/gpfq.py
@@ -2,10 +2,14 @@
# SPDX-License-Identifier: BSD-3-Clause

from copy import deepcopy
import math
from math import pi
from typing import Callable, List, Optional

import numpy as np
import torch
from torch.fft import fft
import torch.nn as nn
import unfoldNd

@@ -18,7 +22,63 @@
from brevitas.graph.gpxq import SUPPORTED_CONV_OP
import brevitas.nn as qnn

TARGET_DIM = 512
TARGET_DIM = 1024
USE_MEM_EFF_RP = True


def dct_pt(x, norm='ortho'):
    """
    Adapted from `https://github.com/zh217/torch-dct`.
    Discrete Cosine Transform, Type II (a.k.a. the DCT)
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension
    """
    x_shape = x.shape
    N = x_shape[-1]
    x = x.reshape(-1, N)

    # interleave the even-index entries with the reversed odd-index entries
    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)

    # unnormalized FFT; the orthonormal scaling is applied explicitly below
    Vc = torch.view_as_real(fft(v, dim=1))

    k = -torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i

    if norm == 'ortho':
        V[:, 0] /= math.sqrt(N) * 2
        V[:, 1:] /= math.sqrt(N / 2) * 2

    V = 2 * V.view(*x_shape)

    return V
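As a quick editorial sanity check (not part of the commit): with norm='ortho', dct_pt should agree with the orthonormal DCT-II from scipy. The snippet below is a minimal sketch; it assumes scipy is available and runs in a scope where dct_pt above is defined.

# Editorial example, not from the commit: compare dct_pt against scipy's DCT-II.
import numpy as np
import torch
from scipy.fftpack import dct as scipy_dct

x = torch.randn(4, 16, dtype=torch.float64)
ours = dct_pt(x, norm='ortho')
ref = scipy_dct(x.numpy(), type=2, norm='ortho', axis=-1)
assert np.allclose(ours.numpy(), ref, atol=1e-6)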


def random_projection(float_input, quantized_input, target_dim):
    # the dim to reduce is dim 1, the sample dimension
    # draw a random +1/-1 (Rademacher) sign per sample:
    # randint yields values in {-1, 0}, so map the zeros to 1
    signs = torch.randint(
        -1, 1, (quantized_input.size(1), 1), dtype=torch.int8, device=float_input.device)
    signs[signs == 0] = 1
    # randomly flip the sign of each sample, identically for both inputs
    f_out = float_input * signs
    q_out = quantized_input * signs
    # apply the discrete cosine transform along the sample dimension
    f_out = dct_pt(f_out.transpose(1, 2), norm='ortho').transpose(1, 2)
    q_out = dct_pt(q_out.transpose(1, 2), norm='ortho').transpose(1, 2)
    # apply a random restriction operator: keep only target_dim of the samples, chosen at random
    perm = torch.randperm(f_out.size(1))
    idx = perm[:target_dim].to(float_input.device)
    return torch.index_select(f_out, dim=1, index=idx), torch.index_select(q_out, dim=1, index=idx)
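A minimal usage sketch (editorial, with hypothetical shapes): both inputs are reduced along dim 1, the sample dimension, and the same sign flips and row selection are applied to each so the float/quantized pair stays aligned. It assumes random_projection above is in scope.

# Editorial example, not from the commit; shapes are hypothetical.
import torch

groups, n_samples, ic = 1, 4096, 64
f = torch.randn(groups, n_samples, ic)
q = torch.randn(groups, n_samples, ic)
f_red, q_red = random_projection(f, q, target_dim=1024)
assert f_red.shape == (1, 1024, 64) and q_red.shape == (1, 1024, 64)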


class gpfq_mode(gpxq_mode):
@@ -248,16 +308,20 @@ def single_layer_update(self):
                weight = weight.transpose(1, 0)  # This performs a view
            weight = weight.flatten(1)
        weight = weight.view(self.groups, -1, weight.shape[-1])  # [Groups, OC/Groups, IC]
        # use random projection to reduce dimensionality
        n = self.quantized_input.size(1)
        # create gaussian random matrix
        R = torch.normal(mean=0.0, std=1. / n, size=(TARGET_DIM, n), device=dev)
        self.quantized_input = torch.transpose(self.quantized_input, 1, 2) @ R.T
        self.float_input = torch.transpose(self.float_input, 1, 2) @ R.T
        del R
        # reshape back
        self.quantized_input = torch.transpose(self.quantized_input, 1, 2).to(dev)
        self.float_input = torch.transpose(self.float_input, 1, 2).to(dev)
        if USE_MEM_EFF_RP:
            self.float_input, self.quantized_input = random_projection(
                self.float_input, self.quantized_input, TARGET_DIM)
        else:
            # use a Gaussian random projection to reduce dimensionality
            n = self.quantized_input.size(1)
            # create a Gaussian random matrix
            R = torch.normal(mean=0.0, std=1. / math.sqrt(n), size=(TARGET_DIM, n), device=dev)
            # R = torch.randn(TARGET_DIM, n, device=dev)
            self.quantized_input = torch.transpose(self.quantized_input, 1, 2) @ R.T
            self.float_input = torch.transpose(self.float_input, 1, 2) @ R.T
            del R
            # reshape back
            self.quantized_input = torch.transpose(self.quantized_input, 1, 2).to(dev)
            self.float_input = torch.transpose(self.float_input, 1, 2).to(dev)
        U = torch.zeros(
            weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype)
        # We don't need full Hessian, we just need the diagonal
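For intuition on why the DCT path is the memory-efficient option: the Gaussian branch materializes R with TARGET_DIM * n entries, whereas the DCT branch only allocates a length-n sign vector and a length-n permutation. A rough back-of-the-envelope sketch (editorial; the sample count n is a hypothetical value, not from the commit):

# Editorial estimate, not from the commit; n is a hypothetical sample count.
n, target_dim = 200_000, 1024
gaussian_bytes = target_dim * n * 4   # fp32 entries of R: ~0.82 GB
sketch_bytes = n * 1 + n * 8          # int8 signs + int64 randperm: ~1.8 MB
print(f"Gaussian R: {gaussian_bytes / 1e9:.2f} GB; DCT extras: {sketch_bytes / 1e6:.2f} MB")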
