Feat (equalize): parametrized scaling #1171

Closed
wants to merge 9 commits
Changes from 1 commit

working but not working
Giuseppe5 authored and pablomlago committed Jan 31, 2025
commit d1ecf1aa548cac50ae8f6505c6b0be52bc84081e
419 changes: 369 additions & 50 deletions src/brevitas/graph/equalize.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/brevitas/graph/quantize_impl.py
@@ -504,6 +504,7 @@ def _module_class_name(module_class_or_str):
     return name
 
 
+from torch.nn.utils.parametrize import type_before_parametrizations
 def find_module(
         model: nn.Module,
         layer_map: Dict[nn.Module, Optional[Dict]],
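The import added here hints at why parametrized scaling needs special handling in find_module: once a scale is registered as a weight parametrization, the module's reported class changes, and type_before_parametrizations recovers the original one so layer-map lookups keyed on the plain class keep matching. A minimal sketch using only standard PyTorch APIs; the ScaleWeight class and the plain nn.Linear are illustrative stand-ins, not the Brevitas implementation:

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize
from torch.nn.utils.parametrize import type_before_parametrizations

class ScaleWeight(nn.Module):
    """Illustrative parametrization: multiply weight columns by a learnable per-channel scale."""

    def __init__(self, num_cols: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(num_cols))

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return weight * self.scale.to(weight.device)

layer = nn.Linear(8, 4, bias=False)
parametrize.register_parametrization(layer, "weight", ScaleWeight(8))

print(type(layer))                          # a dynamically generated ParametrizedLinear class
print(type_before_parametrizations(layer))  # <class 'torch.nn.modules.linear.Linear'>
# A lookup keyed on nn.Linear would miss the parametrized module unless it uses the latter.
```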
2 changes: 1 addition & 1 deletion src/brevitas/utils/rotation_utils.py
@@ -79,4 +79,4 @@ def __init__(self, scaling_factor: Tensor, is_sink: bool) -> None:
     def forward(self, tensor: torch.Tensor) -> torch.Tensor:
         # Reciprocal is done on the fly as to preserve the tie between scale and its reciprocal
         scale = torch.reciprocal(self.scaling_factor) if self.is_sink else self.scaling_factor
-        return tensor * scale
+        return tensor * scale.to(tensor.device)
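For context, the parametrization touched above multiplies a tensor by a scaling factor (source side) or by its reciprocal (sink side), with both sides sharing the same underlying parameter. The sketch below mirrors that forward in isolation; the class name and the toy values are assumptions, only the forward logic comes from the diff. It also shows why the added .to(tensor.device) helps when the parametrized weight has been offloaded to a device other than the one holding the scale:

```python
import torch
import torch.nn as nn

class SharedScale(nn.Module):
    """Illustrative stand-in for the parametrization above (not the Brevitas class)."""

    def __init__(self, scaling_factor: torch.Tensor, is_sink: bool):
        super().__init__()
        self.scaling_factor = scaling_factor  # the same tensor object is shared by source and sink
        self.is_sink = is_sink

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        # Taking the reciprocal lazily keeps the sink tied to the source: if scaling_factor
        # is updated (e.g. trained), both sides immediately see the new value.
        scale = torch.reciprocal(self.scaling_factor) if self.is_sink else self.scaling_factor
        # Moving the scale to the tensor's device covers weights offloaded elsewhere.
        return tensor * scale.to(tensor.device)

shared = nn.Parameter(torch.full((4,), 2.0))
source, sink = SharedScale(shared, is_sink=False), SharedScale(shared, is_sink=True)
w = torch.ones(4)
print(source(w))  # scaled by 2.0
print(sink(w))    # scaled by 1/2.0 = 0.5
```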
2 changes: 1 addition & 1 deletion src/brevitas_examples/common/generative/quantize.py
@@ -338,7 +338,7 @@ def generate_quantizers(
         weight_quant = weight_quant.let(**{'group_size': weight_group_size})
     # weight scale is converted to a standalone parameter
 
-    weight_quant = weight_quant.let(scaling_impl_type='parameter_from_stats')
+    # weight_quant = weight_quant.let(scaling_impl_type='parameter_from_stats')
     # weight zero-point is converted to a standalone parameter
     # This is done already by default in the per_group quantizer
     if weight_quant_type == 'asym' and weight_quant_granularity != 'per_group':
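As a reminder of what the now-commented line was doing: Brevitas quantizers are dependency-injection classes, and .let() returns a new quantizer with the given attributes overridden, so scaling_impl_type='parameter_from_stats' turns the weight scale from a statistics-derived value into a learnable parameter initialised from those statistics. A rough sketch assuming a stock per-tensor quantizer rather than the ones actually built by generate_quantizers; quantizer and layer names are standard Brevitas API, but treat the exact parameter naming as an assumption:

```python
import torch
import brevitas.nn as qnn
from brevitas.quant import Int8WeightPerTensorFloat

# .let() leaves the original quantizer untouched and returns an overridden copy.
ParamScaleWeightQuant = Int8WeightPerTensorFloat.let(scaling_impl_type='parameter_from_stats')

layer = qnn.QuantLinear(16, 32, bias=False, weight_quant=ParamScaleWeightQuant)
out = layer(torch.randn(1, 16))

# The parameter list should now include a scale parameter in addition to 'weight'.
print([name for name, _ in layer.named_parameters()])
```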
61 changes: 55 additions & 6 deletions src/brevitas_examples/llm/main.py
@@ -1,5 +1,7 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = "2"
 
 import argparse
 from contextlib import nullcontext
@@ -14,6 +16,7 @@
 import numpy as np
 from optimum.exporters.onnx import onnx_export_from_model
 import torch
+from tqdm import tqdm
 from transformers import AutoModelForCausalLM
 from transformers import AutoTokenizer
 from transformers.utils.fx import _SUPPORTED_MODELS
@@ -384,6 +387,12 @@ def quantize_llm(args, extra_args=None):
             model, args.act_equalization, loader, alpha=args.act_equalization_alpha)
         print("Act equalization applied.")
         remove_hooks(model)
+        from brevitas.graph.equalize import ParametrizedActivationEqualization
+        obj = ParametrizedActivationEqualization(model)
+        offload_model(model)
+
+        obj.init_parametrized_scales(model, (), calibration_loader[0])
+        remove_hooks(model)
 
     if not args.no_quantize:
         name_blacklist = []
@@ -450,8 +459,8 @@ def quantize_llm(args, extra_args=None):
     # If any equalization has taken places, the embedding layer and the fully connected one are
     # not tied anymore, and they need to be treated as standalone, separate layers.
     # In all other cases we can tie them back so to preserve memory.
-    if args.act_equalization is None and not require_fx and args.rotation is None:
-        model.tie_weights()
+    # if args.act_equalization is None and not require_fx and args.rotation is None:
+    # model.tie_weights()
 
     if args.bias_corr:
         model = add_zero_bias_to_linear(model)
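The comment in this hunk refers to the weight tying between the input embedding and the LM head. A small illustration of that tie using standard transformers APIs; facebook/opt-125m is just this script's default model, used here as an example:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
emb = model.get_input_embeddings().weight
head = model.get_output_embeddings().weight
print(emb is head)  # True while the two layers share a single tensor

# Equalization rescales one side of the pair, so the tensors must stay separate afterwards;
# calling model.tie_weights() again would re-share the embedding tensor and drop the scaling.
```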
@@ -485,9 +494,49 @@ def quantize_llm(args, extra_args=None):
     with torch.no_grad():
         model(**calibration_loader[0])
 
+
     # We restore the original behaviour of the post-forward.
     for k, v in dict_hooks.items():
         k._hf_hook.post_forward = v
+    def get_calib_dataset(data="pileval", n_samples=128, block_size=512):
+        from datasets import load_dataset
+        if data == "pileval":
+            dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
+        else:
+            raise NotImplementedError
+        dataset = dataset.shuffle(seed=42)
+        samples = []
+        n_run = 0
+        for data in dataset:
+            line = data["text"]
+            line = line.strip()
+            line_encoded = tokenizer.encode(line)
+            if len(line_encoded) > 512:
+                continue
+            sample = torch.tensor([line_encoded])
+            if sample.numel() == 0:
+                continue
+            samples.append(sample)
+            n_run += 1
+            if n_run == n_samples:
+                break
+        # now concatenate all samples and split according to block size
+        cat_samples = torch.cat(samples, dim=1)
+        n_split = cat_samples.shape[1] // block_size
+        print(f" * Split into {n_split} blocks")
+        return [
+            cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_split)
+        ]
+    model.config.use_cache = False
+    samples = get_calib_dataset()
+    with torch.no_grad():
+        obj.model = model
+        with obj:
+            for i in tqdm(range(obj.num_layers)):
+                for inp in samples:
+                    model(inp)
+                obj.update()
+    model.config.use_cache = True
 
     if args.optimize_rotations:
         apply_rotation_optimization(
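The calibration helper added above concatenates encoded lines and slices them into fixed-size blocks. The short, self-contained sketch below reproduces just that arithmetic with dummy token tensors, so no tokenizer or dataset download is needed:

```python
import torch

# Stand-ins for tokenized lines of different lengths (the real code uses tokenizer.encode).
samples = [torch.randint(0, 1000, (1, n)) for n in (300, 400, 250)]
block_size = 512

cat_samples = torch.cat(samples, dim=1)        # shape [1, 950]
n_split = cat_samples.shape[1] // block_size   # 1 full block; the 438-token remainder is dropped
blocks = [cat_samples[:, i * block_size:(i + 1) * block_size] for i in range(n_split)]
print(n_split, blocks[0].shape)                # 1 torch.Size([1, 512])
```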
@@ -576,7 +625,7 @@ def quantize_llm(args, extra_args=None):
         print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")
 
     if args.few_shot_eval:
-        with torch.no_grad(), quant_inference_mode(model):
+        with torch.no_grad():#, quant_inference_mode(model):
             model(**calibration_loader[0])
             if args.few_shot_compile:
                 remove_hooks(model)
@@ -642,7 +691,7 @@ def parse_args(args, override_defaults={}):
     parser.add_argument(
         '--model',
         type=str,
-        default="facebook/opt-125m",
+        default="HuggingfaceTB/SmolLM2-135M-Instruct",
         help='HF model name. Default: facebook/opt-125m.')
     parser.add_argument(
         '--seed', type=int, default=0, help='Seed for sampling the calibration data. Default: 0.')
@@ -672,7 +721,7 @@ def parse_args(args, override_defaults={}):
         'Block name for faster GPxQ optimization. It works only if FX is not needed (default: %(default)s)'
     )
     parser.add_argument(
-        '--weight-bit-width', type=int, default=8, help='Weight bit width. Default: 8.')
+        '--weight-bit-width', type=int, default=4, help='Weight bit width. Default: 8.')
     parser.add_argument(
         '--weight-param-method',
         type=str,
@@ -719,7 +768,7 @@ def parse_args(args, override_defaults={}):
     parser.add_argument(
         '--weight-group-size',
         type=int,
-        default=128,
+        default=32,
         help='Group size for per_group weight quantization. Default: 128.')
     parser.add_argument(
         '--quantize-weight-zero-point', action='store_true', help='Quantize weight zero-point.')