From 58d6f152a19a4bc31a25988df4c68e61d29292e0 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 8 Oct 2024 13:38:23 +0100 Subject: [PATCH 1/5] Block gptq --- src/brevitas_examples/llm/llm_quant/gpxq.py | 116 +++++++++++++++----- src/brevitas_examples/llm/main.py | 8 +- 2 files changed, 98 insertions(+), 26 deletions(-) diff --git a/src/brevitas_examples/llm/llm_quant/gpxq.py b/src/brevitas_examples/llm/llm_quant/gpxq.py index e2bfba989..3a7d732f9 100644 --- a/src/brevitas_examples/llm/llm_quant/gpxq.py +++ b/src/brevitas_examples/llm/llm_quant/gpxq.py @@ -3,39 +3,105 @@ # SPDX-License-Identifier: BSD-3-Clause """ + +from copy import deepcopy + +from accelerate.utils.operations import send_to_device import torch from tqdm import tqdm from brevitas.graph.gpfq import gpfq_mode from brevitas.graph.gptq import gptq_mode +from brevitas.graph.gpxq import StopFwdException +from brevitas.utils.python_utils import recurse_getattr + +@torch.no_grad() +def block_optimization(model, dataloader, block_name, context_manager_func, context_manager_kwargs): + cache_state = model.config.use_cache + model.config.use_cache = False + blocks = recurse_getattr(model, block_name) + first_block = blocks[0] + cached_args, cached_kwargs = [], [] + + # Intercept input to first block + def intercept_input(module, args, kwargs): + args = send_to_device(args, 'cpu') + kwargs = send_to_device(kwargs, 'cpu') + cached_args.append(args) + cached_kwargs.append(kwargs) + raise StopFwdException + + # Intercept output from block N-1 to set it as input to block N + def intercept_output(module, args, kwargs, output): + if isinstance(output, tuple): + output = output[0] + output = send_to_device(output, 'cpu') + cached_args.append((output,)) + raise StopFwdException + + # Collect input to first block + hook = first_block.register_forward_pre_hook(intercept_input, with_kwargs=True) + for inps in dataloader: + try: + model(**inps) + except StopFwdException: + pass + hook.remove() + + # Iterate through all the blocks + for index, block in enumerate(tqdm(blocks)): + with context_manager_func(block, **context_manager_kwargs) as gpxq: + for _ in tqdm(range(gpxq.num_layers)): + for args, kwargs in zip(cached_args, cached_kwargs): + args = send_to_device(args, 'cuda') + kwargs = send_to_device(kwargs, 'cuda') + block(*args, **kwargs) + gpxq.update() + if index < len(blocks) - 1: + # Once the block is done, we need to update the input to the next block + past_cached_args, past_cached_kwargs = deepcopy(cached_args), deepcopy(cached_kwargs) + cached_args = [] + hook = block.register_forward_hook(intercept_output, with_kwargs=True) + for args, kwargs in zip(past_cached_args, past_cached_kwargs): + try: + args = send_to_device(args, 'cuda') + kwargs = send_to_device(kwargs, 'cuda') + block(*args, **kwargs) + except StopFwdException: + pass + hook.remove() + # Restore cache state + model.config.use_cache = cache_state @torch.no_grad() -def apply_gptq( - model, - dataloader, - act_order=True, - group_of_parallel_layers=None, - use_quant_activations=True, - create_weight_orig=False): - with gptq_mode(model, - act_order=act_order, - group_of_parallel_layers=group_of_parallel_layers, - use_quant_activations=use_quant_activations, - create_weight_orig=create_weight_orig) as gptq: - gptq_model = gptq.model - for _ in tqdm(range(gptq.num_layers)): - for inps in dataloader: - gptq_model(**inps) - gptq.update() +def apply_gptq(model, dataloader, act_order=True, group_of_parallel_layers=None, block_name=None): + if block_name is not None: + 
context_manager_kwargs = {'act_order': act_order, 'group_of_parallel_layers': group_of_parallel_layers, 'create_weight_orig': False, 'use_quant_activations': False} + block_optimization(model, dataloader, block_name, gptq_mode, context_manager_kwargs) + else: + with gptq_mode(model, + use_quant_activations=False, + group_of_parallel_layers=group_of_parallel_layers, + act_order=act_order, + create_weight_orig=False) as gptq: + gptq_model = gptq.model + for _ in tqdm(range(gptq.num_layers)): + for inps in dataloader: + gptq_model(**inps) + gptq.update() @torch.no_grad() -def apply_gpfq(model, dataloader, act_order=True, group_of_parallel_layers=None): - with gpfq_mode(model, act_order=act_order, - group_of_parallel_layers=group_of_parallel_layers) as gpfq: - gpfq_model = gpfq.model - for _ in tqdm(range(gpfq.num_layers)): - for inps in dataloader: - gpfq_model(**inps) - gpfq.update() +def apply_gpfq(model, dataloader, act_order=True, group_of_parallel_layers=None, block_name=None): + if block_name is not None: + context_manager_kwargs = {'act_order': act_order, 'group_of_parallel_layers': group_of_parallel_layers, 'create_weight_orig': False} + block_optimization(model, dataloader, block_name, gpfq_mode, context_manager_kwargs) + else: + with gpfq_mode(model, act_order=act_order, + group_of_parallel_layers=group_of_parallel_layers) as gpfq: + gpfq_model = gpfq.model + for _ in tqdm(range(gpfq.num_layers)): + for inps in dataloader: + gpfq_model(**inps) + gpfq.update() diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index c33de54c8..f7ca0810d 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -118,7 +118,6 @@ def validate(args): def main(args): validate(args) set_seed(args.seed) - if args.export_prefix is None: args.export_prefix = f"{args.model.replace('/', '--')}" @@ -340,6 +339,13 @@ def parse_args(args): choices=['wikitext2', 'c4'], default='wikitext2', help='Dataset to use for quantization (default: %(default)s)') + parser.add_argument( + '--gptq-block-name', + type=str, + default=None, + help= + 'Block name for faster GPTQ optimization. It works only if FX is not needed (default: %(default)s)' + ) parser.add_argument( '--weight-bit-width', type=int, default=8, help='Weight bit width. 
Default: 8.') parser.add_argument( From da1a24914ae1c711f4383c46cbfbc4c711403393 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 11 Oct 2024 22:23:45 +0100 Subject: [PATCH 2/5] GPxQ generalization --- src/brevitas_examples/llm/main.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index f7ca0810d..bf995a426 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -188,6 +188,8 @@ def main(args): if require_fx: model = get_fx(model) + # Blockwise optimization does not work with FX at the moment + args.gpxq_block_name = None # Apply LN affine merging before inserting MHA layers # since currently there is support only for merging into Linear @@ -284,12 +286,17 @@ def main(args): calibration_loader, act_order=args.gpxq_act_order, use_quant_activations=args.gpxq_use_quant_activations, - create_weight_orig=args.gpxq_create_weight_orig) + create_weight_orig=args.gpxq_create_weight_orig, + block_name=args.gpxq_block_name) print("GPTQ applied.") if args.gpfq: print("Applying GPFQ...") - apply_gpfq(model, calibration_loader, act_order=args.gpxq_act_order) + apply_gpfq( + model, + calibration_loader, + act_order=args.gpxq_act_order, + block_name=args.gpxq_block_name) print("GPFQ applied.") if args.bias_corr: @@ -340,11 +347,11 @@ def parse_args(args): default='wikitext2', help='Dataset to use for quantization (default: %(default)s)') parser.add_argument( - '--gptq-block-name', + '--gpxq-block-name', type=str, default=None, help= - 'Block name for faster GPTQ optimization. It works only if FX is not needed (default: %(default)s)' + 'Block name for faster GPxQ optimization. It works only if FX is not needed (default: %(default)s)' ) parser.add_argument( '--weight-bit-width', type=int, default=8, help='Weight bit width. 
Default: 8.') From 0732d3da08816fab54d970eac9995a14114f1c83 Mon Sep 17 00:00:00 2001 From: i-colbert Date: Tue, 15 Oct 2024 20:22:05 +0000 Subject: [PATCH 3/5] Fix (gpxq): connecting args to gpxq_mode --- src/brevitas_examples/llm/llm_quant/gpxq.py | 134 +++++++++++--------- 1 file changed, 75 insertions(+), 59 deletions(-) diff --git a/src/brevitas_examples/llm/llm_quant/gpxq.py b/src/brevitas_examples/llm/llm_quant/gpxq.py index 3a7d732f9..7867da3df 100644 --- a/src/brevitas_examples/llm/llm_quant/gpxq.py +++ b/src/brevitas_examples/llm/llm_quant/gpxq.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: BSD-3-Clause """ - from copy import deepcopy from accelerate.utils.operations import send_to_device @@ -15,76 +14,89 @@ from brevitas.graph.gpxq import StopFwdException from brevitas.utils.python_utils import recurse_getattr + @torch.no_grad() def block_optimization(model, dataloader, block_name, context_manager_func, context_manager_kwargs): - cache_state = model.config.use_cache - model.config.use_cache = False - blocks = recurse_getattr(model, block_name) - first_block = blocks[0] - cached_args, cached_kwargs = [], [] + cache_state = model.config.use_cache + model.config.use_cache = False + blocks = recurse_getattr(model, block_name) + first_block = blocks[0] + cached_args, cached_kwargs = [], [] - # Intercept input to first block - def intercept_input(module, args, kwargs): - args = send_to_device(args, 'cpu') - kwargs = send_to_device(kwargs, 'cpu') - cached_args.append(args) - cached_kwargs.append(kwargs) - raise StopFwdException + # Intercept input to first block + def intercept_input(module, args, kwargs): + args = send_to_device(args, 'cpu') + kwargs = send_to_device(kwargs, 'cpu') + cached_args.append(args) + cached_kwargs.append(kwargs) + raise StopFwdException - # Intercept output from block N-1 to set it as input to block N - def intercept_output(module, args, kwargs, output): - if isinstance(output, tuple): - output = output[0] - output = send_to_device(output, 'cpu') - cached_args.append((output,)) - raise StopFwdException + # Intercept output from block N-1 to set it as input to block N + def intercept_output(module, args, kwargs, output): + if isinstance(output, tuple): + output = output[0] + output = send_to_device(output, 'cpu') + cached_args.append((output,)) + raise StopFwdException - # Collect input to first block - hook = first_block.register_forward_pre_hook(intercept_input, with_kwargs=True) - for inps in dataloader: - try: - model(**inps) - except StopFwdException: - pass - hook.remove() + # Collect input to first block + hook = first_block.register_forward_pre_hook(intercept_input, with_kwargs=True) + for inps in dataloader: + try: + model(**inps) + except StopFwdException: + pass + hook.remove() - # Iterate through all the blocks - for index, block in enumerate(tqdm(blocks)): - with context_manager_func(block, **context_manager_kwargs) as gpxq: - for _ in tqdm(range(gpxq.num_layers)): - for args, kwargs in zip(cached_args, cached_kwargs): - args = send_to_device(args, 'cuda') - kwargs = send_to_device(kwargs, 'cuda') - block(*args, **kwargs) - gpxq.update() + # Iterate through all the blocks + for index, block in tqdm(enumerate(blocks), desc="Blocks", total=len(blocks)): + with context_manager_func(block, **context_manager_kwargs) as gpxq: + for _ in tqdm(range(gpxq.num_layers), desc="Layers", leave=False): + for args, kwargs in zip(cached_args, cached_kwargs): + args = send_to_device(args, 'cuda') + kwargs = send_to_device(kwargs, 'cuda') + block(*args, **kwargs) + 
gpxq.update() + + if index < len(blocks) - 1: + # Once the block is done, we need to update the input to the next block + past_cached_args, past_cached_kwargs = deepcopy(cached_args), deepcopy(cached_kwargs) + cached_args = [] + hook = block.register_forward_hook(intercept_output, with_kwargs=True) + for args, kwargs in zip(past_cached_args, past_cached_kwargs): + try: + args = send_to_device(args, 'cuda') + kwargs = send_to_device(kwargs, 'cuda') + block(*args, **kwargs) + except StopFwdException: + pass + hook.remove() + # Restore cache state + model.config.use_cache = cache_state - if index < len(blocks) - 1: - # Once the block is done, we need to update the input to the next block - past_cached_args, past_cached_kwargs = deepcopy(cached_args), deepcopy(cached_kwargs) - cached_args = [] - hook = block.register_forward_hook(intercept_output, with_kwargs=True) - for args, kwargs in zip(past_cached_args, past_cached_kwargs): - try: - args = send_to_device(args, 'cuda') - kwargs = send_to_device(kwargs, 'cuda') - block(*args, **kwargs) - except StopFwdException: - pass - hook.remove() - # Restore cache state - model.config.use_cache = cache_state @torch.no_grad() -def apply_gptq(model, dataloader, act_order=True, group_of_parallel_layers=None, block_name=None): +def apply_gptq( + model, + dataloader, + act_order=True, + use_quant_activations=False, + create_weight_orig=False, + group_of_parallel_layers=None, + block_name=None): if block_name is not None: - context_manager_kwargs = {'act_order': act_order, 'group_of_parallel_layers': group_of_parallel_layers, 'create_weight_orig': False, 'use_quant_activations': False} + context_manager_kwargs = { + 'act_order': act_order, + 'group_of_parallel_layers': group_of_parallel_layers, + 'create_weight_orig': create_weight_orig, + 'use_quant_activations': use_quant_activations} block_optimization(model, dataloader, block_name, gptq_mode, context_manager_kwargs) else: with gptq_mode(model, - use_quant_activations=False, + use_quant_activations=use_quant_activations, group_of_parallel_layers=group_of_parallel_layers, act_order=act_order, - create_weight_orig=False) as gptq: + create_weight_orig=create_weight_orig) as gptq: gptq_model = gptq.model for _ in tqdm(range(gptq.num_layers)): for inps in dataloader: @@ -95,11 +107,15 @@ def apply_gptq(model, dataloader, act_order=True, group_of_parallel_layers=None, @torch.no_grad() def apply_gpfq(model, dataloader, act_order=True, group_of_parallel_layers=None, block_name=None): if block_name is not None: - context_manager_kwargs = {'act_order': act_order, 'group_of_parallel_layers': group_of_parallel_layers, 'create_weight_orig': False} + context_manager_kwargs = { + 'act_order': act_order, + 'group_of_parallel_layers': group_of_parallel_layers, + 'create_weight_orig': True} block_optimization(model, dataloader, block_name, gpfq_mode, context_manager_kwargs) else: - with gpfq_mode(model, act_order=act_order, - group_of_parallel_layers=group_of_parallel_layers) as gpfq: + with gpfq_mode(model, + act_order=act_order, + group_of_parallel_layers=group_of_parallel_layers) as gpfq: gpfq_model = gpfq.model for _ in tqdm(range(gpfq.num_layers)): for inps in dataloader: From 0e91f52e401034cc06c895e5e7770b3994dfa7a5 Mon Sep 17 00:00:00 2001 From: i-colbert Date: Tue, 15 Oct 2024 22:16:29 +0000 Subject: [PATCH 4/5] Fix (gpfq): force create_weight_orig to True --- src/brevitas_examples/llm/llm_quant/gpxq.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/brevitas_examples/llm/llm_quant/gpxq.py 
b/src/brevitas_examples/llm/llm_quant/gpxq.py index 7867da3df..022d16f70 100644 --- a/src/brevitas_examples/llm/llm_quant/gpxq.py +++ b/src/brevitas_examples/llm/llm_quant/gpxq.py @@ -115,7 +115,8 @@ def apply_gpfq(model, dataloader, act_order=True, group_of_parallel_layers=None, else: with gpfq_mode(model, act_order=act_order, - group_of_parallel_layers=group_of_parallel_layers) as gpfq: + group_of_parallel_layers=group_of_parallel_layers, + create_weight_orig=True) as gpfq: gpfq_model = gpfq.model for _ in tqdm(range(gpfq.num_layers)): for inps in dataloader: From 2f6619de2eddde6e8f0658abd1f7671d921b07ec Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 17 Oct 2024 15:32:12 +0100 Subject: [PATCH 5/5] Fix for quant_act GPTQ and disable GPFQ block --- src/brevitas_examples/llm/llm_quant/gpxq.py | 155 ++++++++++++-------- 1 file changed, 93 insertions(+), 62 deletions(-) diff --git a/src/brevitas_examples/llm/llm_quant/gpxq.py b/src/brevitas_examples/llm/llm_quant/gpxq.py index 3a7d732f9..6269f0060 100644 --- a/src/brevitas_examples/llm/llm_quant/gpxq.py +++ b/src/brevitas_examples/llm/llm_quant/gpxq.py @@ -3,81 +3,112 @@ # SPDX-License-Identifier: BSD-3-Clause """ - from copy import deepcopy from accelerate.utils.operations import send_to_device import torch from tqdm import tqdm +from brevitas.graph.calibrate import disable_return_quant_tensor +from brevitas.graph.calibrate import DisableEnableQuantization +from brevitas.graph.calibrate import restore_return_quant_tensor from brevitas.graph.gpfq import gpfq_mode from brevitas.graph.gptq import gptq_mode from brevitas.graph.gpxq import StopFwdException from brevitas.utils.python_utils import recurse_getattr + @torch.no_grad() def block_optimization(model, dataloader, block_name, context_manager_func, context_manager_kwargs): - cache_state = model.config.use_cache - model.config.use_cache = False - blocks = recurse_getattr(model, block_name) - first_block = blocks[0] - cached_args, cached_kwargs = [], [] - - # Intercept input to first block - def intercept_input(module, args, kwargs): - args = send_to_device(args, 'cpu') - kwargs = send_to_device(kwargs, 'cpu') - cached_args.append(args) - cached_kwargs.append(kwargs) - raise StopFwdException - - # Intercept output from block N-1 to set it as input to block N - def intercept_output(module, args, kwargs, output): - if isinstance(output, tuple): - output = output[0] - output = send_to_device(output, 'cpu') - cached_args.append((output,)) - raise StopFwdException - - # Collect input to first block - hook = first_block.register_forward_pre_hook(intercept_input, with_kwargs=True) - for inps in dataloader: - try: - model(**inps) - except StopFwdException: - pass - hook.remove() - - # Iterate through all the blocks - for index, block in enumerate(tqdm(blocks)): - with context_manager_func(block, **context_manager_kwargs) as gpxq: - for _ in tqdm(range(gpxq.num_layers)): - for args, kwargs in zip(cached_args, cached_kwargs): - args = send_to_device(args, 'cuda') - kwargs = send_to_device(kwargs, 'cuda') - block(*args, **kwargs) - gpxq.update() - - if index < len(blocks) - 1: - # Once the block is done, we need to update the input to the next block - past_cached_args, past_cached_kwargs = deepcopy(cached_args), deepcopy(cached_kwargs) - cached_args = [] - hook = block.register_forward_hook(intercept_output, with_kwargs=True) - for args, kwargs in zip(past_cached_args, past_cached_kwargs): - try: - args = send_to_device(args, 'cuda') - kwargs = send_to_device(kwargs, 'cuda') - block(*args, 
**kwargs) - except StopFwdException: - pass - hook.remove() - # Restore cache state - model.config.use_cache = cache_state + disable_quant_inference = DisableEnableQuantization() + cache_state = model.config.use_cache + model.config.use_cache = False + blocks = recurse_getattr(model, block_name) + first_block = blocks[0] + cached_args, cached_kwargs = [], [] + + # Intercept input to first block + def intercept_input(module, args, kwargs): + args = send_to_device(args, 'cpu') + kwargs = send_to_device(kwargs, 'cpu') + cached_args.append(args) + cached_kwargs.append(kwargs) + raise StopFwdException + + # Intercept output from block N-1 to set it as input to block N + def intercept_output(module, args, kwargs, output): + if isinstance(output, tuple): + output = output[0] + output = send_to_device(output, 'cpu') + cached_args.append((output,)) + raise StopFwdException + + # Collect input to first block + if not context_manager_kwargs.get('use_quant_activations', True): + return_quant_tensor_state = disable_return_quant_tensor(model) + disable_quant_inference.disable_act_quantization(model, is_training=model.training) + disable_quant_inference.disable_bias_quantization(model, is_training=model.training) + + hook = first_block.register_forward_pre_hook(intercept_input, with_kwargs=True) + for inps in dataloader: + try: + model(**inps) + except StopFwdException: + pass + hook.remove() + + if not context_manager_kwargs.get('use_quant_activations', True): + disable_quant_inference.enable_act_quantization(model, is_training=model.training) + disable_quant_inference.enable_bias_quantization(model, is_training=model.training) + restore_return_quant_tensor(model, return_quant_tensor_state) + + # Iterate through all the blocks + for index, block in enumerate(tqdm(blocks)): + with context_manager_func(block, **context_manager_kwargs) as gpxq: + for _ in tqdm(range(gpxq.num_layers)): + for args, kwargs in zip(cached_args, cached_kwargs): + args = send_to_device(args, 'cuda') + kwargs = send_to_device(kwargs, 'cuda') + block(*args, **kwargs) + gpxq.update() + + if index < len(blocks) - 1: + # Once the block is done, we need to update the input to the next block + past_cached_args, past_cached_kwargs = deepcopy(cached_args), deepcopy(cached_kwargs) + cached_args = [] + hook = block.register_forward_hook(intercept_output, with_kwargs=True) + + if not context_manager_kwargs.get('use_quant_activations', True): + return_quant_tensor_state = disable_return_quant_tensor(model) + disable_quant_inference.disable_act_quantization(model, is_training=model.training) + disable_quant_inference.disable_bias_quantization(model, is_training=model.training) + + for args, kwargs in zip(past_cached_args, past_cached_kwargs): + try: + args = send_to_device(args, 'cuda') + kwargs = send_to_device(kwargs, 'cuda') + block(*args, **kwargs) + except StopFwdException: + pass + + if not context_manager_kwargs.get('use_quant_activations', True): + disable_quant_inference.enable_act_quantization(model, is_training=model.training) + disable_quant_inference.enable_bias_quantization(model, is_training=model.training) + restore_return_quant_tensor(model, return_quant_tensor_state) + + hook.remove() + # Restore cache state + model.config.use_cache = cache_state + @torch.no_grad() def apply_gptq(model, dataloader, act_order=True, group_of_parallel_layers=None, block_name=None): if block_name is not None: - context_manager_kwargs = {'act_order': act_order, 'group_of_parallel_layers': group_of_parallel_layers, 'create_weight_orig': False, 
'use_quant_activations': False} + context_manager_kwargs = { + 'act_order': act_order, + 'group_of_parallel_layers': group_of_parallel_layers, + 'create_weight_orig': False, + 'use_quant_activations': False} block_optimization(model, dataloader, block_name, gptq_mode, context_manager_kwargs) else: with gptq_mode(model, @@ -95,11 +126,11 @@ def apply_gptq(model, dataloader, act_order=True, group_of_parallel_layers=None, @torch.no_grad() def apply_gpfq(model, dataloader, act_order=True, group_of_parallel_layers=None, block_name=None): if block_name is not None: - context_manager_kwargs = {'act_order': act_order, 'group_of_parallel_layers': group_of_parallel_layers, 'create_weight_orig': False} - block_optimization(model, dataloader, block_name, gpfq_mode, context_manager_kwargs) + raise RuntimeError("Block optimization not support for GPFQ at the moment") else: - with gpfq_mode(model, act_order=act_order, - group_of_parallel_layers=group_of_parallel_layers) as gpfq: + with gpfq_mode(model, + act_order=act_order, + group_of_parallel_layers=group_of_parallel_layers) as gpfq: gpfq_model = gpfq.model for _ in tqdm(range(gpfq.num_layers)): for inps in dataloader:
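
Usage sketch for the block-wise path introduced by this series. It can be driven from main.py via the new `--gpxq-block-name` flag (e.g. `--gpxq-block-name model.layers`; that attribute path is an assumption for a Llama-style HF checkpoint, not something fixed by the patches) or directly through `block_optimization`. The snippet below is a minimal sketch, not part of the patches: `quant_model` and `calibration_loader` are assumed to already exist, and the kwargs mirror what `apply_gptq` builds internally on its block path.

    from brevitas.graph.gptq import gptq_mode
    from brevitas_examples.llm.llm_quant.gpxq import block_optimization

    # Same kwargs apply_gptq passes to gptq_mode when block_name is set.
    context_manager_kwargs = {
        'act_order': True,
        'group_of_parallel_layers': None,
        'create_weight_orig': False,
        'use_quant_activations': False}

    # 'model.layers' is an assumed attribute path for Llama-style models;
    # recurse_getattr resolves it to the list of decoder blocks.
    block_optimization(
        quant_model,          # Brevitas-quantized HF causal LM (assumed to exist)
        calibration_loader,   # iterable of dicts accepted by model(**inps)
        'model.layers',
        gptq_mode,
        context_manager_kwargs)

Note that the intercepted block inputs are cached on CPU and sent back to 'cuda' for each block, so a CUDA device is required; after GPTQ finishes on block N, the block is re-run to capture its output as the cached input for block N+1.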