-
Notifications
You must be signed in to change notification settings - Fork 199
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Examples: initial support for LLMs PTQ (#658)
* Examples: WIP LLM block quantization * Add support for block zero-point * Add torch-mlir custom op support * Add test linear Signed-off-by: Alessandro Pappalardo <[email protected]> * Update to custom matmul export Signed-off-by: Alessandro Pappalardo <[email protected]> * Fix errors Signed-off-by: Alessandro Pappalardo <[email protected]> * Fix output shape of custom op Signed-off-by: Alessandro Pappalardo <[email protected]> * Add lowering to torch_mlir for single layer Signed-off-by: Alessandro Pappalardo <[email protected]> * Some cleanups * WIP llm flow Signed-off-by: Alessandro Pappalardo <[email protected]> * Fix (examples/llm): typo in custom quant matmul op (#607) * Test act equalization support * Initial end to end flow * Initial support for QuantMHA on OPT * Fix act equalization * Typos in prints * Reorganize validate * Add initial per row quantizers * Add per row input quantization support * Support group quant slicing * Adopt SliceTensor for block weight partial quant * Add float16 support * Fix scale type name * Add support for LN affine merging * WIP currently broken * Clean up weight eq support * Set weight narrow range always to False * Add fx act equalization, fixes for float16 support * Fix validate * Fix backport imports * Fix example export Signed-off-by: Alessandro Pappalardo <[email protected]> * Fix value_trace call in ln affine merging * Add per tensor/row/group dynamic scale support, some dtype improvements * Fix (llm): correct handling of attention mask shape (#652) * ALways export in fp32 base dtype on CPU * Export improvements * Fix errors after latest PR --------- Signed-off-by: Alessandro Pappalardo <[email protected]> Co-authored-by: jinchen62 <[email protected]> Co-authored-by: Giuseppe Franco <[email protected]>
- Loading branch information
1 parent
b783650
commit 51baf37
Showing
24 changed files
with
2,552 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# LLM quantization | ||
|
||
## Requirements | ||
|
||
- transformers | ||
- datasets | ||
- torch_mlir (optional for torch-mlir based export) | ||
|
||
## Run | ||
|
||
Set the env variable BREVITAS_JIT=1 to speed up the quantization process. Currently unsupported whenever export is also toggled or with MSE based scales/zero-points. | ||
|
||
```bash | ||
usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES] [--seqlen SEQLEN] [--eval] [--weight-bit-width WEIGHT_BIT_WIDTH] [--weight-param-method {stats,mse}] | ||
[--weight-scale-type {float32,po2}] [--weight-quant-type {sym,asym}] [--weight-quant-granularity {per_channel,per_tensor,per_group}] | ||
[--weight-group-size WEIGHT_GROUP_SIZE] [--quantize-weight-zero-point] [--input-bit-width INPUT_BIT_WIDTH] [--input-param-method {stats,mse}] | ||
[--input-scale-type {float32,po2}] [--input-quant-type {sym,asym}] [--input-quant-granularity {per_tensor}] [--quantize-input-zero-point] [--gptq] | ||
[--act-calibration] [--bias-corr] [--act-equalization] | ||
[--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}] | ||
|
||
optional arguments: | ||
-h, --help show this help message and exit | ||
--model MODEL HF model name. Default: facebook/opt-125m. | ||
--seed SEED Seed for sampling the calibration data. Default: 0. | ||
--nsamples NSAMPLES Number of calibration data samples. Default: 128. | ||
--seqlen SEQLEN Sequence length. Default: 2048. | ||
--eval Eval model PPL on C4. | ||
--weight-bit-width WEIGHT_BIT_WIDTH | ||
Weight bit width. Default: 8. | ||
--weight-param-method {stats,mse} | ||
How scales/zero-point are determined. Default: stats. | ||
--weight-scale-type {float32,po2} | ||
Whether scale is a float value or a po2. Default: po2. | ||
--weight-quant-type {sym,asym} | ||
Weight quantization type. Default: asym. | ||
--weight-quant-granularity {per_channel,per_tensor,per_group} | ||
Granularity for scales/zero-point of weights. Default: per_group. | ||
--weight-group-size WEIGHT_GROUP_SIZE | ||
Group size for per_group weight quantization. Default: 128. | ||
--quantize-weight-zero-point | ||
Quantize weight zero-point. | ||
--input-bit-width INPUT_BIT_WIDTH | ||
Input bit width. Default: None (disables input quantization). | ||
--input-param-method {stats,mse} | ||
How scales/zero-point are determined. Default: stats. | ||
--input-scale-type {float32,po2} | ||
Whether input scale is a float value or a po2. Default: float32. | ||
--input-quant-type {sym,asym} | ||
Input quantization type. Default: asym. | ||
--input-quant-granularity {per_tensor} | ||
Granularity for scales/zero-point of inputs. Default: per_tensor. | ||
--quantize-input-zero-point | ||
Quantize input zero-point. | ||
--gptq Apply GPTQ. | ||
--act-calibration Apply activation calibration. | ||
--bias-corr Apply bias correction. | ||
--act-equalization Apply activation equalization (SmoothQuant). | ||
--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight} | ||
Model export. | ||
``` |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
""" | ||
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
""" | ||
|
||
import torch | ||
|
||
from brevitas.graph.calibrate import bias_correction_mode | ||
from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn | ||
|
||
|
||
@torch.no_grad() | ||
def bias_corr_iter(curr_layer, inps, outs, cached_values): | ||
curr_layer = curr_layer.cuda() | ||
with bias_correction_mode(curr_layer): | ||
for j in range(len(inps)): | ||
inp = inps[j].unsqueeze(0).cuda() | ||
curr_out = curr_layer(inp, **cached_values)[0] | ||
outs[j] = curr_out | ||
curr_layer.cpu() | ||
return outs | ||
|
||
|
||
@torch.no_grad() | ||
def apply_bias_correction(model, dataloader, nsamples, seqlen=2048): | ||
apply_layer_ptq_fn(model, dataloader, nsamples, inference_fn=bias_corr_iter, seqlen=seqlen) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
""" | ||
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
""" | ||
|
||
import torch | ||
|
||
from brevitas.graph.calibrate import calibration_mode | ||
from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn | ||
|
||
|
||
@torch.no_grad() | ||
def calibration_iter(curr_layer, inps, outs, cached_values): | ||
curr_layer = curr_layer.cuda() | ||
with calibration_mode(curr_layer): | ||
for j in range(len(inps)): | ||
inp = inps[j].unsqueeze(0).cuda() | ||
curr_out = curr_layer(inp, **cached_values)[0] | ||
outs[j] = curr_out | ||
curr_layer.cpu() | ||
return outs | ||
|
||
|
||
@torch.no_grad() | ||
def apply_calibration(model, dataloader, nsamples, seqlen=2048): | ||
apply_layer_ptq_fn(model, dataloader, nsamples, inference_fn=calibration_iter, seqlen=seqlen) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
""" | ||
Adapted from https://github.com/IST-DASLab/gptq, released under the following LICENSE: | ||
Copyright 2023 IST-DASLab | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import random | ||
|
||
from datasets import load_dataset | ||
import torch | ||
from transformers import AutoTokenizer | ||
|
||
|
||
def get_c4(nsamples, seed, seqlen, model, nvalsamples=256): | ||
traindata = load_dataset( | ||
'allenai/c4', | ||
'allenai--c4', | ||
data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, | ||
split='train', | ||
use_auth_token=False) | ||
valdata = load_dataset( | ||
'allenai/c4', | ||
'allenai--c4', | ||
data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, | ||
split='validation', | ||
use_auth_token=False) | ||
|
||
try: | ||
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) | ||
except: | ||
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) | ||
|
||
random.seed(seed) | ||
trainloader = [] | ||
for _ in range(nsamples): | ||
while True: | ||
i = random.randint(0, len(traindata) - 1) | ||
trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') | ||
if trainenc.input_ids.shape[1] >= seqlen: | ||
break | ||
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) | ||
j = i + seqlen | ||
inp = trainenc.input_ids[:, i:j] | ||
trainloader.append(inp) | ||
|
||
random.seed(0) # hardcoded for validation reproducibility | ||
valenc = [] | ||
for _ in range(nvalsamples): | ||
while True: | ||
i = random.randint(0, len(valdata) - 1) | ||
tmp = tokenizer(valdata[i]['text'], return_tensors='pt') | ||
if tmp.input_ids.shape[1] >= seqlen: | ||
break | ||
i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) | ||
j = i + seqlen | ||
valenc.append(tmp.input_ids[:, i:j]) | ||
|
||
valenc = torch.hstack(valenc) | ||
return trainloader, valenc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
""" | ||
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
""" | ||
|
||
import warnings | ||
|
||
import torch | ||
|
||
from brevitas.fx.brevitas_tracer import value_trace | ||
from brevitas.graph.equalize import activation_equalization_mode | ||
from brevitas.graph.equalize import EqualizeGraph | ||
from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn | ||
from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32 | ||
|
||
|
||
@torch.no_grad() | ||
def activation_equalization_iter(curr_layer, inps, outs, cached_values, alpha): | ||
curr_layer = curr_layer.cuda() | ||
with activation_equalization_mode(curr_layer, alpha, add_mul_node=True, layerwise=True): | ||
for j in range(len(inps)): | ||
inp = inps[j].unsqueeze(0).cuda() | ||
curr_out = curr_layer(inp, **cached_values)[0] | ||
outs[j] = curr_out | ||
curr_layer.cpu() | ||
return outs | ||
|
||
|
||
@torch.no_grad() | ||
def apply_act_equalization( | ||
model, | ||
dtype, | ||
act_equalization_type, | ||
dataloader, | ||
nsamples, | ||
seqlen=2048, | ||
alpha=0.5, | ||
ref_kwargs=None): | ||
if act_equalization_type == 'layerwise': | ||
apply_layer_ptq_fn( | ||
model, | ||
dataloader, | ||
nsamples, | ||
inference_fn=activation_equalization_iter, | ||
seqlen=seqlen, | ||
alpha=alpha) | ||
elif act_equalization_type == 'fx': | ||
assert ref_kwargs is not None, "Ref kwargs required to perform tracing and lift the model into FX." | ||
# We can't do fp16 tracing on CPU as many kernels are not implemented | ||
# So we have to cast to fp32 first, trace, apply equalization, and then cast back | ||
with cast_to_float32(model, dtype): | ||
graph_model = value_trace(model, value_args=ref_kwargs) | ||
# TODO this is currently running on CPU. We need Accelerate or a TorchDispatchMode | ||
# or an FX interpreter to run it on GPU | ||
warnings.warn( | ||
"FX mode activation equalization currently runs on CPU, expect it to be slow for large models." | ||
) | ||
with activation_equalization_mode(graph_model, | ||
alpha, | ||
add_mul_node=False, | ||
layerwise=False): | ||
for input_ids in dataloader: | ||
graph_model(input_ids=input_ids) | ||
else: | ||
raise RuntimeError(f"{act_equalization_type} not supported.") | ||
|
||
|
||
@torch.no_grad() | ||
def apply_weight_equalization(model, dtype, ref_kwargs, scale_computation_type='range'): | ||
# We can't do fp16 tracing on CPU as many kernels are not implemented | ||
# So we have to cast to fp32 first, trace, apply equalization, and then cast back | ||
with cast_to_float32(model, dtype): | ||
graph_model = value_trace(model, value_args=ref_kwargs) | ||
EqualizeGraph(scale_computation_type=scale_computation_type).apply(graph_model) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
""" | ||
Adapted from https://github.com/IST-DASLab/gptq, released under the following LICENSE: | ||
Copyright 2023 IST-DASLab | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import torch | ||
from torch import nn | ||
from tqdm import tqdm | ||
|
||
from brevitas_examples.llm.llm_quant.run_utils import apply_layer_inference_fn | ||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl | ||
from brevitas_examples.llm.llm_quant.run_utils import InputCatcherException | ||
|
||
|
||
def eval_inference_fn(curr_layer, inps, outs, cached_values): | ||
curr_layer.cuda() | ||
for j in range(len(inps)): | ||
outs[j] = curr_layer(inps[j].unsqueeze(0).cuda(), **cached_values)[0] | ||
curr_layer.cpu() | ||
|
||
|
||
@torch.no_grad() | ||
def model_eval(model, valenc, seqlen): | ||
|
||
nsamples = valenc.numel() // seqlen | ||
|
||
def eval_input_capture_fn(model, data): | ||
for i in range(nsamples): | ||
batch = data[:, (i * seqlen):((i + 1) * seqlen)].cuda() | ||
try: | ||
model(batch) | ||
except InputCatcherException: | ||
pass | ||
|
||
inps = apply_layer_inference_fn( | ||
model, | ||
valenc, | ||
nsamples, | ||
input_capture_fn=eval_input_capture_fn, | ||
inference_fn=eval_inference_fn, | ||
seqlen=seqlen) | ||
|
||
model_impl = get_model_impl(model) | ||
use_cache = model.config.use_cache | ||
model.config.use_cache = False | ||
|
||
if hasattr(model_impl, 'norm') and model_impl.norm is not None: | ||
model_impl.norm = model_impl.norm.cuda() | ||
if hasattr(model_impl, 'final_layer_norm') and model_impl.final_layer_norm is not None: | ||
model_impl.final_layer_norm = model_impl.final_layer_norm.cuda() | ||
if hasattr(model_impl, 'project_out') and model_impl.project_out is not None: | ||
model_impl.project_out = model_impl.project_out.cuda() | ||
if hasattr(model, 'lm_head'): | ||
model.lm_head = model.lm_head.cuda() | ||
|
||
valenc = valenc.cuda() | ||
nlls = [] | ||
for i in tqdm(range(nsamples)): | ||
hidden_states = inps[i].unsqueeze(0) | ||
if hasattr(model_impl, 'norm') and model_impl.norm is not None: | ||
hidden_states = model_impl.norm(hidden_states) | ||
if hasattr(model_impl, 'final_layer_norm') and model_impl.final_layer_norm is not None: | ||
hidden_states = model_impl.final_layer_norm(hidden_states) | ||
if hasattr(model_impl, 'project_out') and model_impl.project_out is not None: | ||
hidden_states = model_impl.project_out(hidden_states) | ||
lm_logits = hidden_states | ||
if hasattr(model, 'lm_head'): | ||
lm_logits = model.lm_head(lm_logits) | ||
shift_logits = lm_logits[:, :-1, :].contiguous() | ||
shift_labels = valenc[:, (i * seqlen):((i + 1) * seqlen)][:, 1:] | ||
loss_fct = nn.CrossEntropyLoss() | ||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) | ||
neg_log_likelihood = loss.float() * seqlen | ||
nlls.append(neg_log_likelihood) | ||
|
||
ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen)) | ||
model.config.use_cache = use_cache | ||
return ppl |
Oops, something went wrong.