Feat (examples/generative): block-based optimization for GPTQ #1046

Merged · 6 commits · Oct 17, 2024
src/brevitas_examples/llm/llm_quant/gpxq.py (125 additions, 20 deletions)
@@ -3,39 +3,144 @@
# SPDX-License-Identifier: BSD-3-Clause
"""

from copy import deepcopy

from accelerate.utils.operations import send_to_device
import torch
from tqdm import tqdm

from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.graph.calibrate import restore_return_quant_tensor
from brevitas.graph.gpfq import gpfq_mode
from brevitas.graph.gptq import gptq_mode
from brevitas.graph.gpxq import StopFwdException
from brevitas.utils.python_utils import recurse_getattr


@torch.no_grad()
def block_optimization(model, dataloader, block_name, context_manager_func, context_manager_kwargs):
    disable_quant_inference = DisableEnableQuantization()
    cache_state = model.config.use_cache
    model.config.use_cache = False
    blocks = recurse_getattr(model, block_name)
    first_block = blocks[0]
    cached_args, cached_kwargs = [], []

    # Intercept input to first block
    def intercept_input(module, args, kwargs):
        args = send_to_device(args, 'cpu')
        kwargs = send_to_device(kwargs, 'cpu')
        cached_args.append(args)
        cached_kwargs.append(kwargs)
        raise StopFwdException

    # Intercept output from block N-1 to set it as input to block N
    def intercept_output(module, args, kwargs, output):
        if isinstance(output, tuple):
            output = output[0]
        output = send_to_device(output, 'cpu')
        cached_args.append((output,))
        raise StopFwdException

    # Collect input to first block
    if not context_manager_kwargs.get('use_quant_activations', True):
        return_quant_tensor_state = disable_return_quant_tensor(model)
        disable_quant_inference.disable_act_quantization(model, is_training=model.training)
        disable_quant_inference.disable_bias_quantization(model, is_training=model.training)

    hook = first_block.register_forward_pre_hook(intercept_input, with_kwargs=True)
    for inps in dataloader:
        try:
            model(**inps)
        except StopFwdException:
            pass
    hook.remove()

    if not context_manager_kwargs.get('use_quant_activations', True):
        disable_quant_inference.enable_act_quantization(model, is_training=model.training)
        disable_quant_inference.enable_bias_quantization(model, is_training=model.training)
        restore_return_quant_tensor(model, return_quant_tensor_state)

    # Iterate through all the blocks
    for index, block in tqdm(enumerate(blocks), desc="Blocks", total=len(blocks)):
        with context_manager_func(block, **context_manager_kwargs) as gpxq:
            for _ in tqdm(range(gpxq.num_layers), desc="Layers", leave=False):
                for args, kwargs in zip(cached_args, cached_kwargs):
                    args = send_to_device(args, 'cuda')
                    kwargs = send_to_device(kwargs, 'cuda')
                    block(*args, **kwargs)
                gpxq.update()

        if index < len(blocks) - 1:
            # Once the block is done, we need to update the input to the next block
            past_cached_args, past_cached_kwargs = deepcopy(cached_args), deepcopy(cached_kwargs)
            cached_args = []
            hook = block.register_forward_hook(intercept_output, with_kwargs=True)

            if not context_manager_kwargs.get('use_quant_activations', True):
                return_quant_tensor_state = disable_return_quant_tensor(model)
                disable_quant_inference.disable_act_quantization(model, is_training=model.training)
                disable_quant_inference.disable_bias_quantization(model, is_training=model.training)

            for args, kwargs in zip(past_cached_args, past_cached_kwargs):
                try:
                    args = send_to_device(args, 'cuda')
                    kwargs = send_to_device(kwargs, 'cuda')
                    block(*args, **kwargs)
                except StopFwdException:
                    pass

            if not context_manager_kwargs.get('use_quant_activations', True):
                disable_quant_inference.enable_act_quantization(model, is_training=model.training)
                disable_quant_inference.enable_bias_quantization(model, is_training=model.training)
                restore_return_quant_tensor(model, return_quant_tensor_state)

            hook.remove()
    # Restore cache state
    model.config.use_cache = cache_state


@torch.no_grad()
def apply_gptq(
        model,
        dataloader,
        act_order=True,
-        use_quant_activations=False,
-        create_weight_orig=False,
-        group_of_parallel_layers=None):
-    with gptq_mode(model,
-                   act_order=act_order,
-                   group_of_parallel_layers=group_of_parallel_layers,
-                   use_quant_activations=use_quant_activations,
-                   create_weight_orig=create_weight_orig) as gptq:
-        gptq_model = gptq.model
-        for _ in tqdm(range(gptq.num_layers)):
-            for inps in dataloader:
-                gptq_model(**inps)
-            gptq.update()
+        group_of_parallel_layers=None,
+        use_quant_activations=True,
+        create_weight_orig=False,
+        block_name=None):
+    if block_name is not None:
+        context_manager_kwargs = {
+            'act_order': act_order,
+            'group_of_parallel_layers': group_of_parallel_layers,
+            'create_weight_orig': create_weight_orig,
+            'use_quant_activations': use_quant_activations}
+        block_optimization(model, dataloader, block_name, gptq_mode, context_manager_kwargs)
+    else:
+        with gptq_mode(model,
+                       use_quant_activations=use_quant_activations,
+                       group_of_parallel_layers=group_of_parallel_layers,
+                       act_order=act_order,
+                       create_weight_orig=create_weight_orig) as gptq:
+            gptq_model = gptq.model
+            for _ in tqdm(range(gptq.num_layers)):
+                for inps in dataloader:
+                    gptq_model(**inps)
+                gptq.update()


@torch.no_grad()
-def apply_gpfq(model, dataloader, act_order=True, group_of_parallel_layers=None):
-    with gpfq_mode(model, act_order=act_order,
-                   group_of_parallel_layers=group_of_parallel_layers) as gpfq:
-        gpfq_model = gpfq.model
-        for _ in tqdm(range(gpfq.num_layers)):
-            for inps in dataloader:
-                gpfq_model(**inps)
-            gpfq.update()
+def apply_gpfq(model, dataloader, act_order=True, group_of_parallel_layers=None, block_name=None):
+    if block_name is not None:
+        raise RuntimeError("Block optimization not supported for GPFQ at the moment")
+    else:
+        with gpfq_mode(model,
+                       act_order=act_order,
+                       group_of_parallel_layers=group_of_parallel_layers,
+                       create_weight_orig=True) as gpfq:
+            gpfq_model = gpfq.model
+            for _ in tqdm(range(gpfq.num_layers)):
+                for inps in dataloader:
+                    gpfq_model(**inps)
+                gpfq.update()
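
The new block_optimization helper above works by capture and replay: a forward pre-hook caches the inputs that reach the first transformer block and then aborts the forward pass, and each block is afterwards optimized and re-executed on its own so that its outputs become the calibration inputs of the next block. The snippet below is a minimal, self-contained PyTorch sketch of that pattern on a toy model; the module and variable names (ToyModel, stem, layers, StopFwd) are illustrative stand-ins, not the Brevitas API.

import torch
import torch.nn as nn


class StopFwd(Exception):
    # Raised to abort the forward pass once the first block's inputs are cached
    pass


class ToyModel(nn.Module):
    # Stand-in for a decoder-only LM: a stem followed by a list of blocks

    def __init__(self, dim=8, num_blocks=3):
        super().__init__()
        self.stem = nn.Linear(4, dim)
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_blocks)])

    def forward(self, x):
        x = self.stem(x)
        for layer in self.layers:
            x = layer(x)
        return x


model = ToyModel()
dataloader = [torch.randn(2, 4) for _ in range(4)]
cached_args = []


def intercept_input(module, args):
    # Cache whatever reaches the first block, then stop the forward pass
    cached_args.append(args)
    raise StopFwd


# 1) Capture the inputs to the first block for every calibration batch
hook = model.layers[0].register_forward_pre_hook(intercept_input)
for batch in dataloader:
    try:
        model(batch)
    except StopFwd:
        pass
hook.remove()

# 2) Replay the cached inputs block by block; each block's outputs become
#    the cached inputs of the next block (the real code runs GPxQ per block here)
with torch.no_grad():
    for block in model.layers:
        cached_args = [(block(*args),) for args in cached_args]

Because each block is replayed in isolation, the full model never has to be re-run for every layer being optimized, which is what makes the block-wise path faster on large models.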
src/brevitas_examples/llm/main.py (16 additions, 3 deletions)
@@ -118,7 +118,6 @@ def validate(args):
def main(args):
    validate(args)
    set_seed(args.seed)

    if args.export_prefix is None:
        args.export_prefix = f"{args.model.replace('/', '--')}"

@@ -189,6 +188,8 @@ def main(args):

    if require_fx:
        model = get_fx(model)
        # Blockwise optimization does not work with FX at the moment
        args.gpxq_block_name = None

    # Apply LN affine merging before inserting MHA layers
    # since currently there is support only for merging into Linear
@@ -285,12 +286,17 @@ def main(args):
            calibration_loader,
            act_order=args.gpxq_act_order,
            use_quant_activations=args.gpxq_use_quant_activations,
-            create_weight_orig=args.gpxq_create_weight_orig)
+            create_weight_orig=args.gpxq_create_weight_orig,
+            block_name=args.gpxq_block_name)
        print("GPTQ applied.")

    if args.gpfq:
        print("Applying GPFQ...")
-        apply_gpfq(model, calibration_loader, act_order=args.gpxq_act_order)
+        apply_gpfq(
+            model,
+            calibration_loader,
+            act_order=args.gpxq_act_order,
+            block_name=args.gpxq_block_name)
        print("GPFQ applied.")

    if args.bias_corr:
@@ -340,6 +346,13 @@ def parse_args(args):
        choices=['wikitext2', 'c4'],
        default='wikitext2',
        help='Dataset to use for quantization (default: %(default)s)')
    parser.add_argument(
        '--gpxq-block-name',
        type=str,
        default=None,
        help=
        'Block name for faster GPxQ optimization. It works only if FX is not needed (default: %(default)s)'
    )
    parser.add_argument(
        '--weight-bit-width', type=int, default=8, help='Weight bit width. Default: 8.')
    parser.add_argument(
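
For reference, the new --gpxq-block-name flag expects a dotted attribute path that block_optimization resolves against the model with brevitas.utils.python_utils.recurse_getattr; for LLaMA-style checkpoints the decoder blocks typically live under model.layers, so one would pass --gpxq-block-name model.layers. The sketch below shows the dotted-path resolution using a local functools-based stand-in for that helper and a toy module hierarchy in place of a real causal LM; both are assumptions made purely for illustration.

import functools

import torch.nn as nn


def recurse_getattr(obj, attr: str):
    # Resolve a dotted attribute path such as 'model.layers'
    return functools.reduce(getattr, attr.split('.'), obj)


class ToyDecoder(nn.Module):

    def __init__(self, num_blocks=2):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(num_blocks)])


class ToyCausalLM(nn.Module):

    def __init__(self):
        super().__init__()
        self.model = ToyDecoder()


blocks = recurse_getattr(ToyCausalLM(), 'model.layers')
print(len(blocks))  # 2 blocks, optimized one at a time by block_optimization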