From 25e873551992aff753696057d0f45499576c817d Mon Sep 17 00:00:00 2001
From: Fabian Grob
Date: Mon, 6 May 2024 17:20:06 +0100
Subject: [PATCH] Feat (gpfq): add different way of random projection

---
 src/brevitas/graph/gpfq.py | 84 ++++++++++++++++++++++++++++++++++-----
 1 file changed, 73 insertions(+), 11 deletions(-)

diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py
index 4e873121f..bd47dc9d4 100644
--- a/src/brevitas/graph/gpfq.py
+++ b/src/brevitas/graph/gpfq.py
@@ -2,10 +2,14 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 from copy import deepcopy
+import math
+from math import pi
 from typing import Callable, List, Optional
 
 import numpy as np
 import torch
+from torch.fft import fft
+from torch.fft import fftn
 import torch.nn as nn
 import unfoldNd
 
@@ -18,7 +22,61 @@
 from brevitas.graph.gpxq import SUPPORTED_CONV_OP
 import brevitas.nn as qnn
 
-TARGET_DIM = 512
+TARGET_DIM = 1024
+USE_MEM_EFF_RP = True
+
+
+def dct_pt(x, norm='ortho'):
+    """
+    Adapted from `https://github.com/zh217/torch-dct`.
+    Discrete Cosine Transform, Type II (a.k.a. the DCT).
+
+    For the meaning of the parameter `norm`, see:
+    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
+
+    :param x: the input signal
+    :param norm: the normalization, None or 'ortho'
+    :return: the DCT-II of the signal over the last dimension
+    """
+    x_shape = x.shape
+    N = x_shape[-1]
+    x = x.reshape(-1, N)
+
+    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
+
+    Vc = torch.view_as_real(fft(v, dim=1, norm=norm))
+
+    k = -torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * torch.tensor(pi) / (2 * N)
+    W_r = torch.cos(k)
+    W_i = torch.sin(k)
+
+    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
+
+    if norm == 'ortho':
+        V[:, 0] /= torch.sqrt(torch.tensor(N)) * 2
+        V[:, 1:] /= torch.sqrt(torch.tensor(N / 2)) * 2
+
+    V = 2 * V.view(*x_shape)
+
+    return V
+
+
+def random_projection(float_input, quantized_input, target_dim):
+    # the dim to reduce is the last index
+    # create random +1/-1 entries (the diagonal of a random sign-flip matrix)
+    ones = torch.randint(
+        -1, 1, (quantized_input.size(1), 1), dtype=torch.int8, device=float_input.device)
+    ones[ones == 0] = 1
+    # randomly flip the sign of each row
+    f_out = float_input * ones
+    q_out = quantized_input * ones
+    # apply the discrete cosine transform
+    f_out = dct_pt(f_out, norm='ortho')
+    q_out = dct_pt(q_out, norm='ortho')
+    # now apply a random restriction operator, i.e. keep only target_dim of the rows at random
+    perm = torch.randperm(f_out.size(1))
+    idx = perm[:target_dim].to(float_input.device)
+    return torch.index_select(f_out, dim=1, index=idx), torch.index_select(q_out, dim=1, index=idx)
 
 
 class gpfq_mode(gpxq_mode):
@@ -248,16 +308,20 @@ def single_layer_update(self):
         weight = weight.transpose(1, 0)  # This performs a view
         weight = weight.flatten(1)
         weight = weight.view(self.groups, -1, weight.shape[-1])  # [Groups, OC/Groups, IC]
-        # use random projection to reduce dimensionality
-        n = self.quantized_input.size(1)
-        # create gaussian random matrix
-        R = torch.normal(mean=0.0, std=1. / n, size=(TARGET_DIM, n), device=dev)
-        self.quantized_input = torch.transpose(self.quantized_input, 1, 2) @ R.T
-        self.float_input = torch.transpose(self.float_input, 1, 2) @ R.T
-        del R
-        # reshape back
-        self.quantized_input = torch.transpose(self.quantized_input, 1, 2).to(dev)
-        self.float_input = torch.transpose(self.float_input, 1, 2).to(dev)
+        if USE_MEM_EFF_RP:
+            self.float_input, self.quantized_input = random_projection(self.float_input, self.quantized_input, TARGET_DIM)
+        else:
+            # use random projection to reduce dimensionality
+            n = self.quantized_input.size(1)
+            # create gaussian random matrix
+            R = torch.normal(mean=0.0, std=1. / math.sqrt(n), size=(TARGET_DIM, n), device=dev)
+            # R = torch.randn(TARGET_DIM, n, device=dev)
+            self.quantized_input = torch.transpose(self.quantized_input, 1, 2) @ R.T
+            self.float_input = torch.transpose(self.float_input, 1, 2) @ R.T
+            del R
+            # reshape back
+            self.quantized_input = torch.transpose(self.quantized_input, 1, 2).to(dev)
+            self.float_input = torch.transpose(self.float_input, 1, 2).to(dev)
         U = torch.zeros(
             weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype)
         # We don't need full Hessian, we just need the diagonal
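
Note: the memory-efficient path added by this patch (random sign flip, then an orthonormal DCT, then a random restriction to target_dim coordinates) is the standard fast Johnson-Lindenstrauss / subsampled-randomized-transform construction. The following standalone sketch shows the same construction in NumPy/SciPy for clarity. The file name, the helper name fast_projection, and the sqrt(n / target_dim) rescaling are illustrative assumptions, not part of this patch; the patch itself omits the rescaling, which only scales float_input and quantized_input by the same constant. scipy.fft.dct with norm='ortho' computes the same orthonormal DCT-II as dct_pt above.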
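# srht_sketch.py -- illustrative sketch only, not part of the patch
import numpy as np
from scipy.fft import dct  # orthonormal DCT-II, matches dct_pt(..., norm='ortho')


def fast_projection(x: np.ndarray, target_dim: int, rng: np.random.Generator) -> np.ndarray:
    """Project the last axis of x from n down to target_dim dimensions."""
    n = x.shape[-1]
    signs = rng.choice([-1.0, 1.0], size=n)      # D: random sign flip per coordinate
    y = dct(x * signs, norm='ortho', axis=-1)    # H: orthonormal transform, spreads energy
    idx = rng.permutation(n)[:target_dim]        # P: keep target_dim coordinates at random
    # sqrt(n / target_dim) makes the sketch an isometry in expectation;
    # the patch omits this factor since it rescales both projected inputs equally
    return np.sqrt(n / target_dim) * y[..., idx]


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 4096))
z = fast_projection(x, 1024, rng)
# row norms should agree up to a small Johnson-Lindenstrauss distortion
print(np.round(np.linalg.norm(x, axis=-1), 2))
print(np.round(np.linalg.norm(z, axis=-1), 2))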
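The practical gain over the dense path is memory and speed: the Gaussian branch materializes a TARGET_DIM x n matrix R and pays an O(TARGET_DIM * n) matmul per batch, while the construction above stores only n signs plus TARGET_DIM indices and runs in O(n log n) per row via the FFT-based DCT.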