diff --git a/src/brevitas_examples/llm/llm_quant/bias_corr.py b/src/brevitas_examples/llm/llm_quant/bias_corr.py index 6540af44b..900777874 100644 --- a/src/brevitas_examples/llm/llm_quant/bias_corr.py +++ b/src/brevitas_examples/llm/llm_quant/bias_corr.py @@ -6,21 +6,10 @@ import torch from brevitas.graph.calibrate import bias_correction_mode -from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn - - -@torch.no_grad() -def bias_corr_iter(curr_layer, inps, outs, cached_values): - curr_layer = curr_layer.cuda() - with bias_correction_mode(curr_layer): - for j in range(len(inps)): - inp = inps[j].unsqueeze(0).cuda() - curr_out = curr_layer(inp, **cached_values)[0] - outs[j] = curr_out - curr_layer.cpu() - return outs @torch.no_grad() def apply_bias_correction(model, dataloader): - apply_layer_ptq_fn(model, dataloader, inference_fn=bias_corr_iter) + with bias_correction_mode(curr_layer): + for inps in dataloader: + model(**inps) diff --git a/src/brevitas_examples/llm/llm_quant/calibrate.py b/src/brevitas_examples/llm/llm_quant/calibrate.py index a9153fb4b..93f038d22 100644 --- a/src/brevitas_examples/llm/llm_quant/calibrate.py +++ b/src/brevitas_examples/llm/llm_quant/calibrate.py @@ -7,28 +7,10 @@ from tqdm import tqdm from brevitas.graph.calibrate import calibration_mode -from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn -from brevitas_examples.optimum.utils import offload_model -from brevitas_examples.optimum.utils import remove_hooks @torch.no_grad() -def calibration_iter(curr_layer, inps, outs, cached_values): - curr_layer = curr_layer.cuda() - with calibration_mode(curr_layer): - for j in range(len(inps)): - inp = inps[j].unsqueeze(0).cuda() - curr_out = curr_layer(inp, **cached_values)[0] - outs[j] = curr_out - curr_layer.cpu() - return outs - - -@torch.no_grad() -def apply_calibration(model, dataloader, forward_call): - model = offload_model(model) +def apply_calibration(model, dataloader): with calibration_mode(model): for inps in tqdm(dataloader): - forward_call(model, inps) - # Remove all accelerate hooks - remove_hooks(model) + model(**inps) diff --git a/src/brevitas_examples/llm/llm_quant/equalize.py b/src/brevitas_examples/llm/llm_quant/equalize.py index 536b08560..3ba4ac529 100644 --- a/src/brevitas_examples/llm/llm_quant/equalize.py +++ b/src/brevitas_examples/llm/llm_quant/equalize.py @@ -3,35 +3,19 @@ # SPDX-License-Identifier: BSD-3-Clause """ -from accelerate.hooks import remove_hook_from_module import torch from tqdm import tqdm from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.equalize import EqualizeGraph -from brevitas_examples.optimum.utils import offload_model -from brevitas_examples.optimum.utils import remove_hooks @torch.no_grad() -def activation_equalization_iter(curr_layer, inps, outs, cached_values, alpha): - curr_layer = curr_layer.cuda() - with activation_equalization_mode(curr_layer, alpha, add_mul_node=True, layerwise=True): - for j in range(len(inps)): - inp = inps[j].unsqueeze(0).cuda() - curr_out = curr_layer(inp, **cached_values)[0] - outs[j] = curr_out - curr_layer.cpu() - return outs - - -@torch.no_grad() -def apply_act_equalization(model, act_equalization_type, dataloader, forward_call, alpha=0.5): - model = offload_model(model) +def apply_act_equalization(model, act_equalization_type, dataloader, alpha=0.5): if act_equalization_type == 'layerwise': with activation_equalization_mode(model, alpha, add_mul_node=True, layerwise=True): for inps in tqdm(dataloader): - forward_call(model, inps) + model(**inps) elif act_equalization_type == 'fx': assert model is not None, "FX Model is required to perform FX SmoothQuant" @@ -41,12 +25,10 @@ def apply_act_equalization(model, act_equalization_type, dataloader, forward_cal layerwise=False, co_optimize_act_weights=True): for inps in tqdm(dataloader): - forward_call(model, inps) + model(**inps) else: raise RuntimeError(f"{act_equalization_type} not supported.") - # Remove all accelerate hooks - remove_hooks(model) @torch.no_grad() diff --git a/src/brevitas_examples/llm/llm_quant/eval.py b/src/brevitas_examples/llm/llm_quant/eval.py index 91386207a..14e74496b 100644 --- a/src/brevitas_examples/llm/llm_quant/eval.py +++ b/src/brevitas_examples/llm/llm_quant/eval.py @@ -20,91 +20,29 @@ from torch import nn from tqdm import tqdm -from brevitas_examples.llm.llm_quant.run_utils import apply_layer_inference_fn -from brevitas_examples.llm.llm_quant.run_utils import get_model_impl -from brevitas_examples.llm.llm_quant.run_utils import InputCatcherException - -def eval_inference_fn(curr_layer, inps, outs, cached_values): - curr_layer.cuda() - for j in range(len(inps)): - outs[j] = curr_layer(inps[j].unsqueeze(0).cuda(), **cached_values)[0] - curr_layer.cpu() - - -@torch.no_grad() -def model_eval(model, valenc, seqlen): - - nsamples = valenc.numel() // seqlen - - def eval_input_capture_fn(model, data): - for i in range(nsamples): - batch = data[:, (i * seqlen):((i + 1) * seqlen)].cuda() - try: - model(batch) - except InputCatcherException: - pass - - inps = apply_layer_inference_fn( - model, - valenc, - input_capture_fn=eval_input_capture_fn, - inference_fn=eval_inference_fn, - ) - - model_impl = get_model_impl(model) - use_cache = model.config.use_cache - model.config.use_cache = False - - if hasattr(model_impl, 'norm') and model_impl.norm is not None: - model_impl.norm = model_impl.norm.cuda() - if hasattr(model_impl, 'final_layer_norm') and model_impl.final_layer_norm is not None: - model_impl.final_layer_norm = model_impl.final_layer_norm.cuda() - if hasattr(model_impl, 'project_out') and model_impl.project_out is not None: - model_impl.project_out = model_impl.project_out.cuda() - if hasattr(model, 'lm_head'): - model.lm_head = model.lm_head.cuda() - - valenc = valenc.cuda() - nlls = [] +def create_validation_dataloader(data, seqlen): + nsamples = data['input_ids'].numel() // seqlen + val_dataloader = [] for i in tqdm(range(nsamples)): - hidden_states = inps[i].unsqueeze(0) - if hasattr(model_impl, 'norm') and model_impl.norm is not None: - hidden_states = model_impl.norm(hidden_states) - if hasattr(model_impl, 'final_layer_norm') and model_impl.final_layer_norm is not None: - hidden_states = model_impl.final_layer_norm(hidden_states) - if hasattr(model_impl, 'project_out') and model_impl.project_out is not None: - hidden_states = model_impl.project_out(hidden_states) - lm_logits = hidden_states - if hasattr(model, 'lm_head'): - lm_logits = model.lm_head(lm_logits) - shift_logits = lm_logits[:, :-1, :].contiguous() - shift_labels = valenc[:, (i * seqlen):((i + 1) * seqlen)][:, 1:] - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - neg_log_likelihood = loss.float() * seqlen - nlls.append(neg_log_likelihood) - - ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen)) - model.config.use_cache = use_cache - return ppl + batch = data['input_ids'][:, (i * seqlen):((i + 1) * seqlen)].cuda() + attention_mask = torch.ones_like(batch) + val_dataloader.append({'input_ids': batch, 'attention_mask': attention_mask}) + return val_dataloader @torch.no_grad() -def model_eval_accelerate(model, valenc, seqlen, forward_call): +def model_eval(model, valenc, seqlen): - nsamples = valenc['input_ids'].numel() // seqlen + nsamples = len(valenc) use_cache = model.config.use_cache model.config.use_cache = False with torch.no_grad(): nlls = [] - for i in tqdm(range(nsamples)): - batch = valenc['input_ids'][:, (i * seqlen):((i + 1) * seqlen)].cuda() - attention_mask = torch.ones_like(batch) - lm_logits = forward_call(model, { - 'input_ids': batch, 'attention_mask': attention_mask})['logits'] + for inps in valenc: + lm_logits = model(**inps)['logits'] shift_logits = lm_logits[:, :-1, :].contiguous() - shift_labels = (valenc['input_ids'][:, (i * seqlen):((i + 1) * seqlen)][:, 1:]).cuda() + shift_labels = inps['input_ids'][:, 1:].cuda() loss_fct = nn.CrossEntropyLoss() loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) neg_log_likelihood = loss.float() * seqlen diff --git a/src/brevitas_examples/llm/llm_quant/gptq.py b/src/brevitas_examples/llm/llm_quant/gptq.py index 191775e16..1eafa2851 100644 --- a/src/brevitas_examples/llm/llm_quant/gptq.py +++ b/src/brevitas_examples/llm/llm_quant/gptq.py @@ -3,37 +3,14 @@ # SPDX-License-Identifier: BSD-3-Clause """ -from accelerate.hooks import remove_hook_from_module import torch from tqdm import tqdm from brevitas.graph.gptq import gptq_mode -from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn -from brevitas_examples.optimum.utils import offload_model -from brevitas_examples.optimum.utils import remove_hooks @torch.no_grad() -def gptq_iter(curr_layer, inps, outs, cached_values, act_order): - curr_layer = curr_layer.cuda() - with gptq_mode(curr_layer, use_quant_activations=False, act_order=act_order) as gptq: - gptq_layer = gptq.model - for _ in range(gptq.num_layers): - for j in range(len(inps)): - curr_inp = inps[j].unsqueeze(0).cuda() - gptq_layer(curr_inp, **cached_values) - gptq.update() - for j in range(len(inps)): - inp = inps[j].unsqueeze(0).cuda() - curr_out = curr_layer(inp, **cached_values)[0] - outs[j] = curr_out - curr_layer.cpu() - return outs - - -@torch.no_grad() -def apply_gptq(model, dataloader, forward_call, act_order=True, group_of_parallel_layers=None): - model = offload_model(model) +def apply_gptq(model, dataloader, act_order=True, group_of_parallel_layers=None): with gptq_mode(model, use_quant_activations=False, group_of_parallel_layers=group_of_parallel_layers, @@ -42,7 +19,5 @@ def apply_gptq(model, dataloader, forward_call, act_order=True, group_of_paralle gptq_model = gptq.model for _ in tqdm(range(gptq.num_layers)): for inps in dataloader: - forward_call(gptq_model, inps) + gptq_model(**inps) gptq.update() - # Remove all accelerate hooks - remove_hooks(model) diff --git a/src/brevitas_examples/llm/llm_quant/run_utils.py b/src/brevitas_examples/llm/llm_quant/run_utils.py index 9c9a6636a..b0a2e67f2 100644 --- a/src/brevitas_examples/llm/llm_quant/run_utils.py +++ b/src/brevitas_examples/llm/llm_quant/run_utils.py @@ -20,179 +20,49 @@ """ from contextlib import contextmanager -import functools +import inspect +from optimum.utils.normalized_config import NormalizedConfigManager import torch -from torch import nn from torch.utils._python_dispatch import TorchDispatchMode from torch.utils._pytree import tree_map -from tqdm import tqdm -from transformers.models.opt.modeling_opt import OPTModel +from transformers import AutoConfig from transformers.utils.fx import symbolic_trace -from brevitas.fx.brevitas_tracer import value_trace from brevitas.fx.value_tracer import ValueProxy -from brevitas.graph.standardize import TorchFunctionalToModule -from brevitas.utils.python_utils import recurse_getattr -BLOCK_PATTERNS = [ - "transformer.h", - "model.decoder.layers", - "gpt_neox.layers", - "model.layers",] +def get_fx(model): + forward_signature = inspect.signature(model.forward).parameters + if all(input_name in forward_signature + for input_name in ["input_ids", "attention_mask", "past_key_values"]): + input_names = ["input_ids", "attention_mask", "past_key_values"] + else: + raise ValueError( + f"Quantization with an FX graph is currently only supported for models taking `input_ids`, `attention_mask` and `past_key_values` as inputs. The model only has the following inputs: {forward_signature}" + ) -def get_fx_graph(model, ref_kwargs=None, dtype=None): - try: - graph_model = symbolic_trace(model, list(ref_kwargs.keys())) - except: - assert ref_kwargs is not None, "Symbolic traced failed, pass an example input to perform FX value trace " - with cast_to_float32(model, dtype): - graph_model = value_trace(model, value_args=ref_kwargs) - - graph_model = TorchFunctionalToModule().apply(graph_model) - return graph_model - - -def get_preceding_modules(model: nn.Module, module_name: str): - # From https://github.com/huggingface/optimum/blob/main/optimum/gptq/utils.py - previous_module_name = [] - stop_adding = False - - def _get_preceding_modules(model: nn.Module, module_name: str, name: str = ""): - nonlocal stop_adding - for name_bis, child in model.named_children(): - new_name = name + "." + name_bis if name != "" else name_bis - if new_name == module_name: - stop_adding = True - break - _get_preceding_modules(child, module_name, name=new_name) - if not stop_adding: - previous_module_name.append(name) - return previous_module_name - - return _get_preceding_modules(model, module_name) - - -def get_block_name_with_pattern(model: nn.Module): - """ - From: https://github.com/huggingface/optimum/blob/main/optimum/gptq/utils.py - Get the name of the module that contains the transformers blocks by checking if any modules has a specific pattern - - Args: - model (`nn.Module`): - The input model - Returns: - `str`: The name of the module that contains the Transformer blocks. - """ - modules_names = [n for n, _ in model.named_modules()] - for pattern_candidate in BLOCK_PATTERNS: - pattern_candidate = pattern_candidate - if any(pattern_candidate in name for name in modules_names): - return pattern_candidate - raise ValueError( - "Block pattern could not be match. Pass `block_name_to_quantize` argument in `quantize_model`" - ) - - -def get_model_impl(model): - model_impl = model.model - if isinstance(model_impl, OPTModel): - model_impl = model_impl.decoder - return model_impl - - -class InputCatcherException(Exception): - pass - - -@torch.no_grad() -def calib_input_capture(model, dataloader): - for batch in dataloader: - batch = batch.cuda() - try: - model(batch) - except InputCatcherException: - pass - - -@torch.no_grad() -def capture_first_layer_inputs(input_capture_fn, dataloader, model, layers, preceding_layers_name): - - for module_name in preceding_layers_name: - module = recurse_getattr(model, module_name) - if module is None: - raise ValueError(f"Module {module_name} was not found in model") - module = module.cuda() - - dtype = next(iter(model.parameters())).dtype - inps = [] - cache = {'i': 0} - - class InputCatcher(nn.Module): - - def __init__(self, module): - super().__init__() - self.module = module - - def forward(self, inp, **kwargs): - inps.append(inp) - cache['i'] += 1 - cache['attention_mask'] = kwargs['attention_mask'] - if 'position_ids' in kwargs.keys(): - cache['position_ids'] = kwargs['position_ids'] - raise InputCatcherException - - layers[0] = InputCatcher(layers[0]) - input_capture_fn(model, dataloader) - inps = torch.cat(inps, dim=0).cuda().to(dtype) - - layers[0] = layers[0].module - - for module_name in preceding_layers_name: - module = recurse_getattr(model, module_name) - if module is None: - raise ValueError(f"Module {module_name} was not found in model") - module = module.cpu() - return inps, cache - - -@torch.no_grad() -def apply_layer_inference_fn( - model, dataloader, inference_fn, input_capture_fn, block_name=None, **inference_fn_kwargs): - if block_name is None: - block_name = get_block_name_with_pattern(model) - - layers = recurse_getattr(model, block_name) - module_name_preceding_first_block = get_preceding_modules(model, block_name) - - use_cache = model.config.use_cache - model.config.use_cache = False - - inps, cache = capture_first_layer_inputs( - input_capture_fn, dataloader, model, layers, module_name_preceding_first_block) - outs = torch.zeros_like(inps) - - cached_values = {} - cached_values['attention_mask'] = cache['attention_mask'] - if 'position_ids' in cache.keys(): - cached_values['position_ids'] = cache['position_ids'] - - for curr_layer in tqdm(layers): - inference_fn(curr_layer, inps, outs, cached_values, **inference_fn_kwargs) - inps, outs = outs, inps - - model.config.use_cache = use_cache - return inps - - -def apply_layer_ptq_fn(model, dataloader, inference_fn, **inference_fn_kwargs): - return apply_layer_inference_fn( - model, - dataloader, - inference_fn, - input_capture_fn=calib_input_capture, - **inference_fn_kwargs) + with torch.no_grad(): + model = symbolic_trace(model, input_names) + return model + + +def modify_dataloader(model_name_or_path, data): + config = AutoConfig.from_pretrained(model_name_or_path) + + normalized_config_class = NormalizedConfigManager.get_normalized_config_class(config.model_type) + normalized_config = normalized_config_class(config) + + num_heads = normalized_config.num_attention_heads + head_dim = normalized_config.hidden_size // num_heads + num_layers = normalized_config.num_layers + + for sample in data: + sample["past_key_values"] = tuple(( + torch.zeros(1, num_heads, 0, head_dim, device=sample["input_ids"].device), + torch.zeros(1, num_heads, 0, head_dim, device=sample["input_ids"].device), + ) for _ in range(num_layers)) + return data @contextmanager diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 73a44be98..dec124958 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -7,11 +7,16 @@ import re import numpy as np +from optimum.amd.brevitas.accelerate_utils import offload_model +from optimum.amd.brevitas.accelerate_utils import remove_hooks +from optimum.exporters.onnx.__main__ import onnx_export import torch from transformers import AutoModelForCausalLM +from transformers import AutoTokenizer from brevitas.export import export_onnx_qcdq from brevitas.export import export_torch_qcdq +from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager from brevitas_examples.common.generative.quantize import quantize_model from brevitas_examples.common.parse_utils import quant_format_validator from brevitas_examples.llm.llm_quant.bias_corr import apply_bias_correction @@ -20,13 +25,16 @@ from brevitas_examples.llm.llm_quant.data import get_wikitext2 from brevitas_examples.llm.llm_quant.equalize import apply_act_equalization from brevitas_examples.llm.llm_quant.equalize import apply_weight_equalization +from brevitas_examples.llm.llm_quant.eval import create_validation_dataloader from brevitas_examples.llm.llm_quant.eval import model_eval +from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager +from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode from brevitas_examples.llm.llm_quant.gptq import apply_gptq from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32 -from brevitas_examples.llm.llm_quant.run_utils import get_fx_graph -from brevitas_examples.llm.llm_quant.run_utils import get_model_impl +from brevitas_examples.llm.llm_quant.run_utils import get_fx +from brevitas_examples.llm.llm_quant.run_utils import modify_dataloader parser = argparse.ArgumentParser() parser.add_argument( @@ -56,7 +64,7 @@ parser.add_argument( '--weight-quant-type', type=str, - default='asym', + default='sym', choices=['sym', 'asym'], help='Weight quantization type. Default: asym.') parser.add_argument( @@ -68,7 +76,7 @@ parser.add_argument( '--weight-quant-granularity', type=str, - default='per_tensor', + default='per_group', choices=['per_channel', 'per_tensor', 'per_group'], help='Granularity for scales/zero-point of weights. Default: per_group.') parser.add_argument( @@ -117,7 +125,7 @@ parser.add_argument( '--input-quant-granularity', type=str, - default='per_group', + default='per_tensor', choices=['per_tensor', 'per_row', 'per_group'], help='Granularity for scales/zero-point of inputs. Default: per_tensor.') parser.add_argument( @@ -179,7 +187,18 @@ def model_export(model, ref_input, args): sharded_weight_group_export sharded_weight_group_export(model, no_custom_packed_export=False) elif args.export_target == 'onnx_qcdq': - export_onnx_qcdq(model, ref_input, export_path=f"{args.model.replace('/', '-')}.onnx") + if args.weight_quant_granularity == 'per_group': + export_manager = BlockQuantProxyLevelManager + else: + export_manager = StdQCDQONNXManager + export_manager.change_weight_export(export_weight_q_node=True) + print(f"Exporting the model in ./quantized_onnx/{args.model.replace('/', '-')}") + with torch.no_grad(), brevitas_proxy_export_mode(model, export_manager=export_manager): + onnx_export( + model, + f"./quantized_onnx/{args.model.replace('/', '-')}", + task="text-generation-with-past", + do_validation=False) elif args.export_target == 'torch_qcdq': export_torch_qcdq(model, ref_input, export_path=f"{args.model.replace('/', '-')}.pt") @@ -199,7 +218,8 @@ def validate(args): assert args.input_bit_width is None, "Sharded packed torch group weight export doesn't support input quant." assert not args.quantize_weight_zero_point, "Quantized weight zero point not supported." if args.export_target == 'onnx_qcdq': - assert args.weight_quant_granularity != 'per_group', "ONNX QCDQ export doesn't support group weight quantization." + if args.weight_quant_granularity == 'per_group': + assert args.input_bit_width is None, "ONNX QCDQ per_group quantization requires no input quantization" if args.weight_quant_type == 'asym': assert args.quantize_weight_zero_point, "Quantized weight zero point required." if args.input_bit_width is not None and args.input_quant_type == 'asym': @@ -231,6 +251,7 @@ def main(): model = AutoModelForCausalLM.from_pretrained(args.model, **kwargs) print("Model loaded.") model.eval() + tokenizer = AutoTokenizer.from_pretrained(args.model) if args.load_awq: from brevitas_examples.llm.llm_quant.awq.pre_quant import apply_awq @@ -241,8 +262,15 @@ def main(): if (args.export_target or args.eval or args.act_equalization or args.act_calibration or args.gptq or args.bias_corr or args.ln_affine_merge or args.weight_equalization): print("Data loading...") - calibration_loader, val_data = get_wikitext2( - nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=args.seqlen) + calibration_loader = get_wikitext2( + nsamples=args.nsamples, tokenizer=tokenizer, seqlen=args.seqlen, seed=0) + val_data = get_wikitext2( + nsamples=args.nsamples, + tokenizer=tokenizer, + seqlen=args.seqlen, + split='validation', + seed=0) + val_data = create_validation_dataloader(val_data, args.seqlen) print("Data loaded.") # Apply LN affine merging before inserting MHA layers @@ -259,28 +287,27 @@ def main(): model = replace_mha_with_quantizable_layers(model, dtype) print("Replacing done.") - graph_model = get_fx_graph(model, ref_kwargs={'input_ids': calibration_loader[0]}, dtype=dtype) + if args.weight_equalization or args.act_equalization == 'fx': + model = get_fx(model) + calibration_loader = modify_dataloader(args.model, calibration_loader) + val_data = modify_dataloader(args.model, val_data) if args.weight_equalization: print("Apply weight equalization...") - apply_weight_equalization(graph_model) + apply_weight_equalization(model) print("Weight equalization applied.") if args.act_equalization is not None: + offload_model(model) print("Apply act equalization (SmoothQuant)...") - apply_act_equalization( - model, args.act_equalization, calibration_loader, graph_model=graph_model) + apply_act_equalization(model, args.act_equalization, calibration_loader) print("Act equalization applied.") - - if args.quantize_embedding or args.quantize_last_layer: - layers_to_quantize = model - else: - layers_to_quantize = get_model_impl(model).layers + remove_hooks(model) if not args.no_quantize: print("Applying model quantization...") quantize_model( - layers_to_quantize, + model, dtype=dtype, weight_quant_format=args.weight_quant_format, weight_quant_type=args.weight_quant_type, @@ -299,8 +326,7 @@ def main(): input_quant_granularity=args.input_quant_granularity, input_group_size=args.input_group_size, quantize_input_zero_point=args.quantize_input_zero_point, - quantize_embedding=args.quantize_embedding, - seqlen=args.seqlen) + quantize_embedding=args.quantize_embedding) # Tie back first/last layer weights in case they got untied print("Model quantization applied.") @@ -310,6 +336,7 @@ def main(): if args.act_equalization is None and not args.weight_equalization: model.tie_weights() + model = offload_model(model) if args.act_calibration: print("Apply act calibration...") apply_calibration(model, calibration_loader) @@ -329,6 +356,7 @@ def main(): print("Model eval...") ppl = model_eval(model, val_data, args.seqlen) print(f"C4 perplexity: {ppl}") + remove_hooks(model) if args.export_target: print(f"Export to {args.export_target}")