Fix for quant_act GPTQ and disable GPFQ block
Giuseppe5 committed Oct 17, 2024
1 parent da1a249 commit 2f6619d
Showing 1 changed file with 93 additions and 62 deletions.
155 changes: 93 additions & 62 deletions src/brevitas_examples/llm/llm_quant/gpxq.py
@@ -3,81 +3,112 @@
# SPDX-License-Identifier: BSD-3-Clause
"""


from copy import deepcopy

from accelerate.utils.operations import send_to_device
import torch
from tqdm import tqdm

from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.graph.calibrate import restore_return_quant_tensor
from brevitas.graph.gpfq import gpfq_mode
from brevitas.graph.gptq import gptq_mode
from brevitas.graph.gpxq import StopFwdException
from brevitas.utils.python_utils import recurse_getattr


@torch.no_grad()
def block_optimization(model, dataloader, block_name, context_manager_func, context_manager_kwargs):
cache_state = model.config.use_cache
model.config.use_cache = False
blocks = recurse_getattr(model, block_name)
first_block = blocks[0]
cached_args, cached_kwargs = [], []

# Intercept input to first block
def intercept_input(module, args, kwargs):
args = send_to_device(args, 'cpu')
kwargs = send_to_device(kwargs, 'cpu')
cached_args.append(args)
cached_kwargs.append(kwargs)
raise StopFwdException

# Intercept output from block N-1 to set it as input to block N
def intercept_output(module, args, kwargs, output):
if isinstance(output, tuple):
output = output[0]
output = send_to_device(output, 'cpu')
cached_args.append((output,))
raise StopFwdException

# Collect input to first block
hook = first_block.register_forward_pre_hook(intercept_input, with_kwargs=True)
for inps in dataloader:
try:
model(**inps)
except StopFwdException:
pass
hook.remove()

# Iterate through all the blocks
for index, block in enumerate(tqdm(blocks)):
with context_manager_func(block, **context_manager_kwargs) as gpxq:
for _ in tqdm(range(gpxq.num_layers)):
for args, kwargs in zip(cached_args, cached_kwargs):
args = send_to_device(args, 'cuda')
kwargs = send_to_device(kwargs, 'cuda')
block(*args, **kwargs)
gpxq.update()

if index < len(blocks) - 1:
# Once the block is done, we need to update the input to the next block
past_cached_args, past_cached_kwargs = deepcopy(cached_args), deepcopy(cached_kwargs)
cached_args = []
hook = block.register_forward_hook(intercept_output, with_kwargs=True)
for args, kwargs in zip(past_cached_args, past_cached_kwargs):
try:
args = send_to_device(args, 'cuda')
kwargs = send_to_device(kwargs, 'cuda')
block(*args, **kwargs)
except StopFwdException:
pass
hook.remove()
# Restore cache state
model.config.use_cache = cache_state
disable_quant_inference = DisableEnableQuantization()
cache_state = model.config.use_cache
model.config.use_cache = False
blocks = recurse_getattr(model, block_name)
first_block = blocks[0]
cached_args, cached_kwargs = [], []

# Intercept input to first block
def intercept_input(module, args, kwargs):
args = send_to_device(args, 'cpu')
kwargs = send_to_device(kwargs, 'cpu')
cached_args.append(args)
cached_kwargs.append(kwargs)
raise StopFwdException

# Intercept output from block N-1 to set it as input to block N
def intercept_output(module, args, kwargs, output):
if isinstance(output, tuple):
output = output[0]
output = send_to_device(output, 'cpu')
cached_args.append((output,))
raise StopFwdException

# Collect input to first block
if not context_manager_kwargs.get('use_quant_activations', True):
return_quant_tensor_state = disable_return_quant_tensor(model)
disable_quant_inference.disable_act_quantization(model, is_training=model.training)
disable_quant_inference.disable_bias_quantization(model, is_training=model.training)

hook = first_block.register_forward_pre_hook(intercept_input, with_kwargs=True)
for inps in dataloader:
try:
model(**inps)
except StopFwdException:
pass
hook.remove()

if not context_manager_kwargs.get('use_quant_activations', True):
disable_quant_inference.enable_act_quantization(model, is_training=model.training)
disable_quant_inference.enable_bias_quantization(model, is_training=model.training)
restore_return_quant_tensor(model, return_quant_tensor_state)

# Iterate through all the blocks
for index, block in enumerate(tqdm(blocks)):
with context_manager_func(block, **context_manager_kwargs) as gpxq:
for _ in tqdm(range(gpxq.num_layers)):
for args, kwargs in zip(cached_args, cached_kwargs):
args = send_to_device(args, 'cuda')
kwargs = send_to_device(kwargs, 'cuda')
block(*args, **kwargs)
gpxq.update()

if index < len(blocks) - 1:
# Once the block is done, we need to update the input to the next block
past_cached_args, past_cached_kwargs = deepcopy(cached_args), deepcopy(cached_kwargs)
cached_args = []
hook = block.register_forward_hook(intercept_output, with_kwargs=True)

if not context_manager_kwargs.get('use_quant_activations', True):
return_quant_tensor_state = disable_return_quant_tensor(model)
disable_quant_inference.disable_act_quantization(model, is_training=model.training)
disable_quant_inference.disable_bias_quantization(model, is_training=model.training)

for args, kwargs in zip(past_cached_args, past_cached_kwargs):
try:
args = send_to_device(args, 'cuda')
kwargs = send_to_device(kwargs, 'cuda')
block(*args, **kwargs)
except StopFwdException:
pass

if not context_manager_kwargs.get('use_quant_activations', True):
disable_quant_inference.enable_act_quantization(model, is_training=model.training)
disable_quant_inference.enable_bias_quantization(model, is_training=model.training)
restore_return_quant_tensor(model, return_quant_tensor_state)

hook.remove()
# Restore cache state
model.config.use_cache = cache_state


@torch.no_grad()
def apply_gptq(model, dataloader, act_order=True, group_of_parallel_layers=None, block_name=None):
if block_name is not None:
context_manager_kwargs = {'act_order': act_order, 'group_of_parallel_layers': group_of_parallel_layers, 'create_weight_orig': False, 'use_quant_activations': False}
context_manager_kwargs = {
'act_order': act_order,
'group_of_parallel_layers': group_of_parallel_layers,
'create_weight_orig': False,
'use_quant_activations': False}
block_optimization(model, dataloader, block_name, gptq_mode, context_manager_kwargs)
else:
with gptq_mode(model,
@@ -95,11 +126,11 @@ def apply_gptq(model, dataloader, act_order=True, group_of_parallel_layers=None,
@torch.no_grad()
def apply_gpfq(model, dataloader, act_order=True, group_of_parallel_layers=None, block_name=None):
if block_name is not None:
context_manager_kwargs = {'act_order': act_order, 'group_of_parallel_layers': group_of_parallel_layers, 'create_weight_orig': False}
block_optimization(model, dataloader, block_name, gpfq_mode, context_manager_kwargs)
        raise RuntimeError("Block optimization not supported for GPFQ at the moment")
else:
with gpfq_mode(model, act_order=act_order,
group_of_parallel_layers=group_of_parallel_layers) as gpfq:
with gpfq_mode(model,
act_order=act_order,
group_of_parallel_layers=group_of_parallel_layers) as gpfq:
gpfq_model = gpfq.model
for _ in tqdm(range(gpfq.num_layers)):
for inps in dataloader:
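The core of the quant_act fix is the enable/disable bracketing around input caching: when use_quant_activations is False, activation and bias quantization (and QuantTensor returns) are switched off while calibration inputs are collected for the first block and while cached inputs are propagated to the next block, then restored before the GPxQ updates run. Below is a minimal sketch of that pattern using only the Brevitas calibration helpers imported in this diff; run_without_act_quant and calibration_forward are hypothetical names for illustration, not part of the commit.

from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import restore_return_quant_tensor


def run_without_act_quant(model, calibration_forward):
    # Temporarily disable activation/bias quantization and QuantTensor
    # returns, run the calibration forward passes, then restore everything.
    disable_quant = DisableEnableQuantization()
    return_state = disable_return_quant_tensor(model)
    disable_quant.disable_act_quantization(model, is_training=model.training)
    disable_quant.disable_bias_quantization(model, is_training=model.training)
    try:
        calibration_forward(model)
    finally:
        disable_quant.enable_act_quantization(model, is_training=model.training)
        disable_quant.enable_bias_quantization(model, is_training=model.training)
        restore_return_quant_tensor(model, return_state)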

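After this change, block-wise optimization is only wired up for GPTQ; passing block_name to apply_gpfq raises a RuntimeError, while whole-model GPFQ still works. A hedged usage sketch follows; quant_model, calib_loader, and the 'model.layers' block path are assumptions for a typical decoder-only LLM, not values taken from the commit.

from brevitas_examples.llm.llm_quant.gpxq import apply_gpfq
from brevitas_examples.llm.llm_quant.gpxq import apply_gptq

# quant_model: an already-quantized causal LM; calib_loader: an iterable of
# dicts of model kwargs (e.g. input_ids, attention_mask).
apply_gptq(
    quant_model,
    calib_loader,
    act_order=True,
    group_of_parallel_layers=None,
    block_name='model.layers')  # assumed block path for Llama-style models

# Block-wise GPFQ is disabled by this commit, so block_name must stay None here.
apply_gpfq(quant_model, calib_loader, act_order=True)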