Rotation based equalization #1061

Merged (14 commits) on Nov 13, 2024
557 changes: 480 additions & 77 deletions src/brevitas/graph/equalize.py

Large diffs are not rendered by default.

168 changes: 168 additions & 0 deletions src/brevitas/graph/hadamard.py
@@ -0,0 +1,168 @@
# This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot).
# Licensed under Apache License 2.0.

import math
import os
import pathlib

try:
    import fast_hadamard_transform
except ImportError:
    fast_hadamard_transform = None
import torch

# Adapted from https://github.com/Cornell-RelaxML/quip-sharp/blob/main/lib/utils/matmul_had.py


def get_hadK(n, transpose=False):
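    # Decompose n as K * 2^m: K is either 1 (n itself is a power of two) or one of
    # the known non-power-of-two Hadamard sizes stored in hadamard_tensors.pt.
    # Returns the K x K Hadamard matrix (None when K == 1) and K.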
parent = pathlib.Path(os.path.abspath(__file__)).parent
    # Hadamard matrices for had12, had36.pal2, had52.will,
    # had60.pal, had108.pal, had140.pal, had156.will, had172.will:
# http://www.neilsloane.com/hadamard/index.html
tensors = torch.load(str(parent) + '/hadamard_tensors.pt')
tensors = {k: v.to(torch.float) for k, v in tensors.items()}
hadK, K = None, None
if n % 172 == 0: # llama-2-7b up
assert (is_pow2(n // 172))
K = 172
hadK = tensors['get_had172'].T if transpose else tensors['get_had172']
elif n % 156 == 0: # llama-1-30b 3x hidden
assert (is_pow2(n // 156))
K = 156
hadK = tensors['get_had156'].T if transpose else tensors['get_had156']
elif n % 140 == 0: # llama-1-30b intermediate
assert (is_pow2(n // 140))
K = 140
hadK = tensors['get_had140'].T if transpose else tensors['get_had140']
elif n % 108 == 0: # llama-1-13b intermediate
assert (is_pow2(n // 108))
K = 108
hadK = tensors['get_had108'].T if transpose else tensors['get_had108']
elif n % 60 == 0: # llama-1-13b 3x hidden
assert (is_pow2(n // 60))
K = 60
hadK = tensors['get_had60'].T if transpose else tensors['get_had60']
elif n % 52 == 0: # llama-1-13b 1x hidden
assert (is_pow2(n // 52))
K = 52
hadK = tensors['get_had52'].T if transpose else tensors['get_had52']
elif n % 36 == 0:
assert (is_pow2(n // 36))
K = 36
hadK = tensors['get_had36'].T if transpose else tensors['get_had36']
elif n % 28 == 0:
assert (is_pow2(n // 28))
K = 28
hadK = tensors['get_had28'].T if transpose else tensors['get_had28']
elif n % 40 == 0:
assert (is_pow2(n // 40))
K = 40
hadK = tensors['get_had40'].T if transpose else tensors['get_had40']
elif n % 20 == 0:
assert (is_pow2(n // 20))
K = 20
hadK = tensors['get_had20'].T if transpose else tensors['get_had20']
elif n % 12 == 0:
assert (is_pow2(n // 12))
K = 12
hadK = tensors['get_had12'].T if transpose else tensors['get_had12']
else:
assert (is_pow2(n))
K = 1

return hadK, K


def matmul_hadU(X, transpose=False):
n = X.shape[-1]
hadK, K = get_hadK(n, transpose)
input = X.clone().view(-1, n, 1)
output = input.clone()
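    # Radix-2 butterfly stages: repeatedly form (a + b, a - b) pairs until the
    # block size is down to K, then apply the dense K x K Hadamard matrix below.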
while input.shape[1] > K:
input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
output = output.view(input.shape)
output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
output = output.view(input.shape[0], input.shape[1], -1)
(input, output) = (output, input)
del output

if K > 1:
# Do not explicitly repeat - OOM
# input = torch.bmm(
# hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
# Use bcast instead
input = hadK.view(1, K, K).to(input) @ input

return input.view(X.shape) / torch.tensor(n).sqrt()


def matmul_hadUt(X):
return matmul_hadU(X, transpose=True)


def random_hadamard_matrix(size, device):
# See https://github.com/Cornell-RelaxML/quip-sharp , Section "Randomized Hadamard Transformation"
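    # Build a random +/-1 diagonal matrix and push it through the Hadamard transform:
    # the result is a sign diagonal composed with a scaled Hadamard matrix, hence orthogonal,
    # which is what allows the rotation to be fused into adjacent weights without changing outputs.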
Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64)
Q = Q * 2 - 1
Q = torch.diag(Q)
return matmul_hadU(Q).to(device)


def matmul_hadU_cuda(X, hadK, K, transpose=False):
    n = X.shape[-1]
    if K == 1:
        return fast_hadamard_transform.hadamard_transform(
            X.contiguous(), 1.0 / torch.tensor(n).sqrt())
    if transpose:
        hadK = hadK.T.contiguous()
    input = X.view(*X.shape[:-1], K, n // K)
    input = fast_hadamard_transform.hadamard_transform(
        input.contiguous(), 1.0 / torch.tensor(n).sqrt())
    input = hadK.to(input.device).to(input.dtype) @ input
    return input.reshape(X.shape)


def matmul_hadUt_cuda(X, hadK, K):
return matmul_hadU_cuda(X, hadK, K, transpose=True)


def apply_exact_had_to_linear(module, had_dim=-1, output=False):
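    # Rotate a Linear layer's weight in place with an exact Hadamard transform:
    # output=True rotates along the output-feature dimension, otherwise along the
    # input-feature dimension; had_dim != -1 applies the transform to chunks of
    # that size instead of the full dimension.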
assert isinstance(module, torch.nn.Linear)
in_features, out_features = module.in_features, module.out_features

if had_dim != -1:
assert is_pow2(had_dim), "Hadamard dimension must be a power of 2!"

W_ = module.weight.data
dtype = W_.dtype
dev = W_.device
init_shape = W_.shape
W_ = W_.float().cuda()

if had_dim == -1:
if output:
had_K, K = get_hadK(out_features)
W_ = matmul_hadU_cuda(W_.t(), had_K, K).t()
if not output:
had_K, K = get_hadK(in_features)
W_ = matmul_hadU_cuda(W_, had_K, K)
else:
# Apply Hadamard to the last had_dim chunks of the weights
if output:
W_ = W_.t()
transposed_shape = W_.shape
W_ = fast_hadamard_transform.hadamard_transform(
W_.reshape(-1, transposed_shape[-1] // had_dim, had_dim),
scale=1 / math.sqrt(had_dim)).reshape(transposed_shape).t()
        else:
            raise NotImplementedError("Not implemented (or tested) yet!")
            # Unreachable sketch of the input-side chunked path; kept for reference.
            n = W_.shape[1]
            W_ = fast_hadamard_transform.hadamard_transform(
                W_.reshape(-1, n // had_dim, had_dim),
                scale=1 / math.sqrt(had_dim)).reshape(init_shape)
module.weight.data = W_.to(device=dev, dtype=dtype)


def is_pow2(n):
return (n & (n - 1) == 0) and (n > 0)
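As a quick sanity check for the helpers above (not part of the diff; it assumes `brevitas.graph.hadamard` is importable and the bundled `hadamard_tensors.pt` is present), the fused transform should preserve norms, since a Hadamard matrix divided by sqrt(n) is orthogonal:

```python
import torch

from brevitas.graph.hadamard import get_hadK, matmul_hadU

# 384 = 12 * 2^5 exercises both the butterfly stages and the dense 12x12 block.
n = 12 * 32
hadK, K = get_hadK(n)
assert K == 12

x = torch.randn(4, n)
y = matmul_hadU(x)

# H / sqrt(n) is orthogonal, so norms are preserved up to floating-point error.
assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-4)
```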
Binary file added src/brevitas/graph/hadamard_tensors.pt
Binary file not shown.
50 changes: 50 additions & 0 deletions src/brevitas/nn/equalized_layer.py
Expand Up @@ -2,8 +2,16 @@

import torch

from brevitas.graph.hadamard import get_hadK
from brevitas.graph.hadamard import matmul_hadU
from brevitas.graph.hadamard import matmul_hadU_cuda
from brevitas.nn.quant_mha import QuantMultiheadAttention

try:
    import fast_hadamard_transform
except ImportError:
    fast_hadamard_transform = None

INPUT_NAMES = ['input', 'inp', 'query', 'x', 'hidden_states']


Expand Down Expand Up @@ -41,3 +49,45 @@ def forward(self, *args, **kwargs):
# We convert everything to args so that hooks can work correctly
out = self.layer(*kwargs.values())
return out


class RotatedModule(torch.nn.Module):

def __init__(self, layer, had_mat=None, k=None) -> None:
super().__init__()
if had_mat is not None:
self.had_mat = torch.nn.Parameter(had_mat).cpu()
else:
self.had_mat = None
self.layer = layer
self.k = k

def forward(self, inp, **kwargs):
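        # Rotate the incoming activations online: use the fused CUDA kernel from
        # fast_hadamard_transform when it is available, otherwise fall back to the
        # pure-PyTorch matmul_hadU path.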
is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None
if is_cuda and fast_hadamard_transform is not None:
if self.had_mat is None or self.k is None:
had_K, K = get_hadK(inp.shape[-1])
else:
had_K = self.had_mat
K = self.k
inp = matmul_hadU_cuda(inp, had_K, K)
else:
inp = matmul_hadU(inp)
o = self.layer(inp)

return o


def functional_rotate_input(inp, transpose=False):
is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None
if transpose:
inp = inp.t()
if is_cuda and fast_hadamard_transform is not None:
had_K, K = get_hadK(inp.shape[-1])
inp = matmul_hadU_cuda(inp, had_K, K)
else:
inp = matmul_hadU(inp)

if transpose:
inp = inp.t()
return inp
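A minimal usage sketch for `RotatedModule` (again not part of the diff; it assumes the packaged Hadamard tensors are available): the wrapper rotates the layer input with a Hadamard transform before calling the wrapped layer.

```python
import torch

from brevitas.nn.equalized_layer import RotatedModule

layer = torch.nn.Linear(768, 768)
# had_mat / k are optional: when omitted, get_hadK picks them from the input size.
rotated = RotatedModule(layer)

x = torch.randn(2, 768)
out = rotated(x)  # x is Hadamard-rotated, then passed through the wrapped Linear
print(out.shape)  # torch.Size([2, 768])
```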
13 changes: 12 additions & 1 deletion src/brevitas_examples/llm/README.md
Expand Up @@ -16,6 +16,7 @@ Set the env variable BREVITAS_JIT=1 to speed up the quantization process. Curren
```bash
usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
[--seqlen SEQLEN] [--eval] [--dataset {wikitext2,c4}]
[--gpxq-block-name GPXQ_BLOCK_NAME]
[--weight-bit-width WEIGHT_BIT_WIDTH]
[--weight-param-method {stats,mse,hqo}]
[--weight-scale-precision {float_scale,po2_scale}]
Expand Down Expand Up @@ -53,7 +54,10 @@ options:
--seqlen SEQLEN Sequence length. Default: 2048.
--eval Eval model PPL on the chosen Dataset.
--dataset {wikitext2,c4}
Dataset to use for quantization (default: wikitext2)
Dataset to use for quantization (default: c4)
--gpxq-block-name GPXQ_BLOCK_NAME
Block name for faster GPxQ optimization. It works only
if FX is not needed (default: None)
--weight-bit-width WEIGHT_BIT_WIDTH
Weight bit width. Default: 8.
--weight-param-method {stats,mse,hqo}
Expand Down Expand Up @@ -121,6 +125,7 @@ options:
--act-calibration Apply activation calibration.
--bias-corr Apply bias correction.
--ln-affine-merge Merge LN affine params.
--replace-rmsnorm Replace HF RMSNorms with Torch one.
--no-quantize Disable quantization.
--no-float16 Disable float16 as base datatype and switch to
float32.
Expand All @@ -129,6 +134,12 @@ options:
--weight-equalization
Apply weight equalization. Relevant to ReLU based
models (e.g. OPT).
--graph-rotation Apply graph rotation equalization
--graph-rotation-mode {had,ort}
If GraphRotation is enabled, decide how to compute the
random rotation matrix that is fully fused. Online or
partial rotation will always be Hadamard
--layerwise-rotation Apply layerwise rotation equalization
--act-equalization {None,layerwise,fx}
Apply activation equalization (SmoothQuant). Layerwise
introduces standalone mul nodes, while fx merges them
Expand Down
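For reference, a hedged invocation sketch combining the rotation options above with flags already documented in this README (the model id is only a placeholder):

```bash
python main.py \
  --model meta-llama/Llama-2-7b-hf \
  --eval \
  --replace-rmsnorm \
  --graph-rotation \
  --graph-rotation-mode had \
  --weight-bit-width 4
```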
42 changes: 34 additions & 8 deletions src/brevitas_examples/llm/llm_quant/ln_affine_merge.py
Expand Up @@ -3,13 +3,34 @@
SPDX-License-Identifier: MIT
"""

from packaging import version
import torch
from torch import nn

from brevitas.graph.equalize import _is_reshaping_op
from brevitas import torch_version
from brevitas.graph.base import ModuleToModuleByClass
from brevitas.graph.equalize import _is_scale_invariant_module
from brevitas.graph.equalize import LayerNormToRMS
from brevitas.graph.equalize import MergeLnAffine
from brevitas.graph.utils import get_module
from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32


def replace_rmsnorm_with_torch(model, config):
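    # Collect every module class whose name contains 'RMS' (the HF per-model
    # RMSNorm implementations) and rewrite each of them to torch.nn.RMSNorm,
    # reusing the model's hidden size, eps and dtype.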
assert torch_version >= version.parse('2.4'), "torch.nn.RMSNorm requires torch 2.4 or greater"
set_of_layers = set(type(x) for x in model.modules() if 'RMS' in type(x).__name__)
dtype = next(model.parameters()).dtype
rewriters = [
ModuleToModuleByClass(
rms_cls,
torch.nn.RMSNorm,
normalized_shape=config.hidden_size,
eps=config.rms_norm_eps,
dtype=dtype) for rms_cls in set_of_layers]
for r in rewriters:
model = r.apply(model)
model = model.to(dtype)
return model


def replace_bias(next_module, new_bias):
Expand Down Expand Up @@ -49,7 +70,7 @@ def merge_layernorm_affine_params(graph_model):
module = get_module(graph_model, node.target)
if isinstance(module, nn.LayerNorm):
for next in node.users:
while (_is_reshaping_op(next) or _is_scale_invariant_module(graph_model, next)):
while (_is_scale_invariant_module(graph_model, next)):
next = node.next
if next.op == 'call_module':
next_module = get_module(graph_model, next.target)
Expand Down Expand Up @@ -83,8 +104,13 @@ def merge_layernorm_affine_params(graph_model):


@torch.no_grad()
def apply_layernorm_affine_merge(graph_model, dtype):
# We can't do fp16 tracing on CPU as many kernels are not implemented
# So we have to cast to fp32 first, trace, apply merging, and then cast back
with cast_to_float32(graph_model, dtype):
merge_layernorm_affine_params(graph_model)
def apply_layernorm_affine_merge(graph_model):
eq = MergeLnAffine()
graph_model = eq.apply(graph_model)
return graph_model


@torch.no_grad()
def apply_layernorm_to_rmsnorm(graph_model, return_rewriters=False):
eq = LayerNormToRMS(return_rewriters)
return eq.apply(graph_model)
18 changes: 17 additions & 1 deletion src/brevitas_examples/llm/llm_quant/run_utils.py
Expand Up @@ -32,11 +32,13 @@
from brevitas.fx.value_tracer import ValueProxy


def get_fx(model):
def get_fx(model, is_export=True):
forward_signature = inspect.signature(model.forward).parameters
if all(input_name in forward_signature
for input_name in ["input_ids", "attention_mask", "past_key_values"]):
input_names = ["input_ids", "attention_mask", "past_key_values"]
if not is_export:
input_names.remove('past_key_values')
else:
raise ValueError(
f"Quantization with an FX graph is currently only supported for models taking `input_ids`, `attention_mask` and `past_key_values` as inputs. The model only has the following inputs: {forward_signature}"
Expand Down Expand Up @@ -106,3 +108,17 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
args, kwargs = tree_map(self.cast_cpu_float32, (args, kwargs))
out = func(*args, **kwargs)
return out


# This function remaps rewriters so that they match modules in a potentially different model
# that shares the same underlying tensors.
# Two versions of the same model (eager vs FX) may have different module ids
# (id(fx_module) != id(eager_module)), but the underlying tensors are still shared,
# so we can reconstruct the mapping between the two sets of modules.
def fix_rewriter(rewriters, old_model_ref, tensor_name):
    for r in rewriters:
        tensor_id = id(getattr(r.old_module_instance, tensor_name))
        module = [
            m for m in old_model_ref.modules()
            if hasattr(m, tensor_name) and id(getattr(m, tensor_name)) == tensor_id]
        r.old_module_instance = module[0]
    return rewriters