
Commit d5b5350

Fix
Giuseppe5 committed Oct 17, 2024
2 parents 2f6619d + 0e91f52 commit d5b5350
Showing 1 changed file with 16 additions and 8 deletions.
src/brevitas_examples/llm/llm_quant/gpxq.py (16 additions, 8 deletions)
@@ -63,9 +63,9 @@ def intercept_output(module, args, kwargs, output):
     restore_return_quant_tensor(model, return_quant_tensor_state)
 
     # Iterate through all the blocks
-    for index, block in enumerate(tqdm(blocks)):
+    for index, block in tqdm(enumerate(blocks), desc="Blocks", total=len(blocks)):
         with context_manager_func(block, **context_manager_kwargs) as gpxq:
-            for _ in tqdm(range(gpxq.num_layers)):
+            for _ in tqdm(range(gpxq.num_layers), desc="Layers", leave=False):
                 for args, kwargs in zip(cached_args, cached_kwargs):
                     args = send_to_device(args, 'cuda')
                     kwargs = send_to_device(kwargs, 'cuda')
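Side note (not part of the diff): the new keyword arguments matter because tqdm cannot infer a length from an enumerate object, so total must be passed explicitly for the outer bar to show a percentage and ETA, while desc labels each bar and leave=False clears the inner per-layer bar when it finishes. A minimal, self-contained sketch of the same pattern, with a hypothetical blocks list standing in for the model's transformer blocks:

# Minimal sketch (not from this commit) of the nested progress-bar pattern.
from tqdm import tqdm

blocks = ['block0', 'block1', 'block2']  # hypothetical stand-in for model blocks

for index, block in tqdm(enumerate(blocks), desc="Blocks", total=len(blocks)):
    # Without total=len(blocks) the outer bar has no percentage or ETA,
    # because enumerate() does not implement __len__.
    for _ in tqdm(range(4), desc="Layers", leave=False):
        pass  # per-layer optimization happens here in the real loop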
@@ -102,20 +102,27 @@ def intercept_output(module, args, kwargs, output):
 
 
 @torch.no_grad()
-def apply_gptq(model, dataloader, act_order=True, group_of_parallel_layers=None, block_name=None):
+def apply_gptq(
+        model,
+        dataloader,
+        act_order=True,
+        use_quant_activations=False,
+        create_weight_orig=False,
+        group_of_parallel_layers=None,
+        block_name=None):
     if block_name is not None:
         context_manager_kwargs = {
             'act_order': act_order,
             'group_of_parallel_layers': group_of_parallel_layers,
-            'create_weight_orig': False,
-            'use_quant_activations': False}
+            'create_weight_orig': create_weight_orig,
+            'use_quant_activations': use_quant_activations}
         block_optimization(model, dataloader, block_name, gptq_mode, context_manager_kwargs)
     else:
         with gptq_mode(model,
-                       use_quant_activations=False,
+                       use_quant_activations=use_quant_activations,
                        group_of_parallel_layers=group_of_parallel_layers,
                        act_order=act_order,
-                       create_weight_orig=False) as gptq:
+                       create_weight_orig=create_weight_orig) as gptq:
             gptq_model = gptq.model
             for _ in tqdm(range(gptq.num_layers)):
                 for inps in dataloader:
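Side note (not part of the diff): apply_gptq now forwards use_quant_activations and create_weight_orig from the caller instead of hard-coding them to False, and the new defaults preserve the old behaviour. A hedged call sketch under assumptions; get_model_and_loader, model, calibration_loader, and the 'model.layers' block name are placeholders, not part of this commit:

# Hypothetical call sketch; the new flags default to False, matching the
# previously hard-coded values, so existing callers are unaffected.
from brevitas_examples.llm.llm_quant.gpxq import apply_gptq

model, calibration_loader = get_model_and_loader()  # placeholder setup

apply_gptq(
    model,
    calibration_loader,
    act_order=True,
    use_quant_activations=True,  # was always False before this change
    create_weight_orig=True,     # was always False before this change
    group_of_parallel_layers=None,
    block_name='model.layers')   # hypothetical block name for a decoder-style LLM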
@@ -130,7 +137,8 @@ def apply_gpfq(model, dataloader, act_order=True, group_of_parallel_layers=None,
     else:
         with gpfq_mode(model,
                        act_order=act_order,
-                       group_of_parallel_layers=group_of_parallel_layers) as gpfq:
+                       group_of_parallel_layers=group_of_parallel_layers,
+                       create_weight_orig=True) as gpfq:
             gpfq_model = gpfq.model
             for _ in tqdm(range(gpfq.num_layers)):
                 for inps in dataloader:

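Side note (not part of the diff): apply_gpfq keeps its original signature; the commit instead enables create_weight_orig unconditionally inside gpfq_mode, presumably because GPFQ reconstructs the quantized weights against the original floating-point ones. A hedged call sketch with the same placeholder setup as above:

# Hypothetical call sketch; create_weight_orig is not an argument here, it is
# now forced to True inside apply_gpfq.
from brevitas_examples.llm.llm_quant.gpxq import apply_gpfq

model, calibration_loader = get_model_and_loader()  # placeholder setup

apply_gpfq(model, calibration_loader, act_order=True, group_of_parallel_layers=None)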