diff --git a/src/brevitas_examples/llm/llm_quant/gptq.py b/src/brevitas_examples/llm/llm_quant/gptq.py
index 85b160c3e..595428d49 100644
--- a/src/brevitas_examples/llm/llm_quant/gptq.py
+++ b/src/brevitas_examples/llm/llm_quant/gptq.py
@@ -4,79 +4,88 @@
 """
 
 from copy import deepcopy
+from time import sleep
+
+from accelerate.utils.operations import send_to_device
 import torch
 from tqdm import tqdm
 
 from brevitas.graph.gptq import gptq_mode
-from accelerate.utils.operations import send_to_device
+from brevitas.graph.gpxq import StopFwdException
+from brevitas.utils.python_utils import recurse_getattr
 
 
 @torch.no_grad()
-def apply_gptq(model, dataloader, act_order=True, group_of_parallel_layers=None):
-    if True:
-        blocks = model.model.layers #getattr(model, block_name)
+def apply_gptq(model, dataloader, act_order=True, group_of_parallel_layers=None, block_name=None):
+    if block_name is not None:
+        cache_state = model.config.use_cache
+        model.config.use_cache = False
+        blocks = recurse_getattr(model, block_name)
         first_block = blocks[0]
         cached_args, cached_kwargs = [], []
 
+        # Intercept input to first block
        def intercept_input(module, args, kwargs):
             args = send_to_device(args, 'cpu')
             kwargs = send_to_device(kwargs, 'cpu')
             cached_args.append(args)
             cached_kwargs.append(kwargs)
-            raise RuntimeError
+            raise StopFwdException
+
+        # Intercept output from block N-1 to set it as input to block N
         def intercept_output(module, args, kwargs, output):
             if isinstance(output, tuple):
                 output = output[0]
+            output = send_to_device(output, 'cpu')
             cached_args.append((output,))
-            raise RuntimeError
-
+            raise StopFwdException
+        # Collect input to first block
         hook = first_block.register_forward_pre_hook(intercept_input, with_kwargs=True)
         for inps in dataloader:
             try:
                 model(**inps)
-            except:
+            except StopFwdException:
                 pass
         hook.remove()
 
-
+        # Iterate through all the blocks
         for index, block in enumerate(tqdm(blocks)):
+
             with gptq_mode(block,
-                    use_quant_activations=False,
-                    group_of_parallel_layers=group_of_parallel_layers,
-                    act_order=act_order,
-                    create_weight_orig=False) as gptq:
+                           use_quant_activations=False,
+                           group_of_parallel_layers=group_of_parallel_layers,
+                           act_order=act_order,
+                           create_weight_orig=False) as gptq:
                 for _ in tqdm(range(gptq.num_layers)):
                     for args, kwargs in zip(cached_args, cached_kwargs):
                         args = send_to_device(args, 'cuda')
                         kwargs = send_to_device(kwargs, 'cuda')
                         block(*args, **kwargs)
-                        args = send_to_device(args, 'cpu')
-                        kwargs = send_to_device(kwargs, 'cpu')
                     gptq.update()
 
-            past_cached_args, past_cached_kwargs = deepcopy(cached_args), deepcopy(cached_kwargs)
-            cached_args = []
-            if index < len(blocks)-1:
-                hook = blocks[index].register_forward_hook(intercept_output, with_kwargs=True)
+            if index < len(blocks) - 1:
+                # Once the block is done, we need to update the input to the next block
+                past_cached_args, past_cached_kwargs = deepcopy(cached_args), deepcopy(cached_kwargs)
+                cached_args = []
+                hook = block.register_forward_hook(intercept_output, with_kwargs=True)
                 for args, kwargs in zip(past_cached_args, past_cached_kwargs):
                     try:
                         args = send_to_device(args, 'cuda')
                         kwargs = send_to_device(kwargs, 'cuda')
                         block(*args, **kwargs)
-                        args = send_to_device(args, 'cpu')
-                        kwargs = send_to_device(kwargs, 'cpu')
-                    except Exception as e:
+                    except StopFwdException:
                         pass
                 hook.remove()
 
-
+        # Restore cache state
+        model.config.use_cache = cache_state
     else:
         with gptq_mode(model,
-                use_quant_activations=False,
-                group_of_parallel_layers=group_of_parallel_layers,
-                act_order=act_order,
-                create_weight_orig=False) as gptq:
+                       use_quant_activations=False,
+                       group_of_parallel_layers=group_of_parallel_layers,
+                       act_order=act_order,
+                       create_weight_orig=False) as gptq:
             gptq_model = gptq.model
             for _ in tqdm(range(gptq.num_layers)):
                 for inps in dataloader:
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index e19390774..a3b447cbf 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -113,7 +113,6 @@ def validate(args):
 def main(args):
     validate(args)
     set_seed(args.seed)
-
     if args.export_prefix is None:
         args.export_prefix = f"{args.model.replace('/', '--')}"
 
@@ -325,6 +324,13 @@ def parse_args(args):
         choices=['wikitext2', 'c4'],
         default='wikitext2',
         help='Dataset to use for quantization (default: %(default)s)')
+    parser.add_argument(
+        '--gptq-block-name',
+        type=str,
+        default=None,
+        help=
+        'Block name for faster GPTQ optimization. It works only if FX is not needed (default: %(default)s)'
+    )
     parser.add_argument(
         '--weight-bit-width', type=int, default=8, help='Weight bit width. Default: 8.')
     parser.add_argument(
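
For reference, a minimal usage sketch of the block-wise path added above. The checkpoint name, calibration text and the 'model.decoder.layers' attribute path are illustrative assumptions (an OPT-style Hugging Face model), and the Brevitas weight-quantization step that main.py performs before calling apply_gptq is omitted here:

# Hypothetical usage of the block-wise GPTQ path introduced in this diff.
# Model name, calibration text and the dotted block path are assumptions;
# in the real entry point the model is first quantized with Brevitas.
from transformers import AutoModelForCausalLM, AutoTokenizer

from brevitas_examples.llm.llm_quant.gptq import apply_gptq

model = AutoModelForCausalLM.from_pretrained('facebook/opt-125m').to('cuda')
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')

# Calibration data: a list of dicts matching the model's forward signature.
# The hooks cache the first block's inputs on CPU and move them back to
# 'cuda' one block at a time, so only the current block stays busy on GPU.
calib_dataloader = [
    tokenizer("Some calibration text.", return_tensors='pt').to('cuda')
    for _ in range(4)]

# Dotted attribute path resolved by recurse_getattr to reach the decoder
# blocks; on the CLI this is the value carried by --gptq-block-name.
apply_gptq(
    model,
    calib_dataloader,
    act_order=True,
    block_name='model.decoder.layers')

Replaying cached block inputs one block at a time is what the new --gptq-block-name flag enables: GPTQ then only needs the current block resident on the accelerator instead of running the whole model per calibration sample.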